
feat: BF16 GEMM using CUTLASS backend for SM100 #2070

Merged
aleozlx merged 18 commits into flashinfer-ai:main from raayandhar:user/rdhar/cutlass_bf16_gemm_sm100
Jan 10, 2026

Conversation

@raayandhar
Contributor

@raayandhar raayandhar commented Nov 10, 2025

📌 Description

This issue was opened a little while ago (#1974) and I finally got a chance to tackle it. It is a feature request for BF16 GEMM, which I decided to implement using the CUTLASS backend. The issue poster was using a B200, so I implemented it for B200 (SM100) as well.

🔍 Related Issues

Feature request: #1974

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added high-performance BF16 matrix-multiply APIs (mm_bf16, bmm_bf16) with selectable backend, workspace management, autotuning support, and a runtime tactic query.
    • Integrated Cutlass-based BF16 GEMM runner, JIT generation for SM100 BF16 kernels, and public runtime entry points for native execution.
  • Documentation

    • Added BF16 GEMM docs and autosummary entries.
  • Tests

    • Added unit tests validating mm_bf16 and bmm_bf16 on supported GPUs.


@coderabbitai
Contributor

coderabbitai Bot commented Nov 10, 2025

📝 Walkthrough


Adds end-to-end SM100 BF16 GEMM: Cutlass-based CUDA runner and TVM FFI, SM100 Jinja codegen and kernel templates, public C++ runner headers, Python mm_bf16/bmm_bf16 APIs with JIT/autotune integration, tests, and docs; includes workspace sizing, tactic selection, and runtime validation.
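
For orientation, here is a minimal usage sketch of the two new entry points. The parameter names (out_dtype, backend) and the context-manager use of autotune are inferred from the validation code and tests discussed later in this review, so treat the exact signatures as assumptions rather than the documented API:

import torch
import flashinfer
from flashinfer.autotuner import autotune

a = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
b = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)

# 2D GEMM on the CUTLASS SM100 path; the output dtype may be bf16 or fp16.
c = flashinfer.mm_bf16(a, b, out_dtype=torch.bfloat16, backend="cutlass")

# Batched variant: (batch, m, k) x (batch, k, n) -> (batch, m, n).
A = torch.randn(4, 128, 256, device="cuda", dtype=torch.bfloat16)
B = torch.randn(4, 256, 512, device="cuda", dtype=torch.bfloat16)
C = flashinfer.bmm_bf16(A, B, out_dtype=torch.bfloat16, backend="cutlass")

# Optionally let the autotuner pick a kernel tactic, as the tests do.
with autotune():
    c = flashinfer.mm_bf16(a, b, out_dtype=torch.bfloat16, backend="cutlass")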

Changes

Cohort / File(s) Summary
CUDA FFI & Runner Implementation
csrc/bf16_gemm_cutlass.cu
New CUDA source exposing TVM FFI bf16_gemm and bf16_gemm_tactic_num; runtime tactic selection, runGemm<T>, workspace sizing/allocation, 2D/batched handling, input/output validation, and explicit instantiations for CutlassBf16GemmRunner<__nv_bfloat16> and CutlassBf16GemmRunner<half>.
Jinja Instantiations
csrc/bf16_gemm_cutlass.jinja
New Jinja template emitting multiple SM100 BF16/half kernel instantiations (various CTA/cluster combos) for Cutlass BF16 path.
Public Headers — Runner Interface
include/flashinfer/gemm/bf16_gemm_cutlass.h
Adds CutlassBf16GemmRunnerInterface and templated CutlassBf16GemmRunner<T> declarations (gemm, getWorkspaceSize, getConfigs).
SM100 Kernel Templates & Launchers
include/flashinfer/gemm/bf16_gemm_cutlass_template.h, include/flashinfer/gemm/bf16_gemm_template_sm100.h
Adds SM100-specific BF16 kernel launcher/dispatch, SMTypeAdapter, _1SM/_2SM markers, genericBf16GemmKernelLauncherSm100, arch/cluster dispatch, workspace probing/memoization, and INSTANCE_BF16_GEMM_TEMPLATE_SM100 macro.
Python API & Backend Integration
flashinfer/gemm/gemm_base.py, flashinfer/gemm/__init__.py, flashinfer/__init__.py
Adds mm_bf16 and bmm_bf16 public APIs, BF16-specific validation/heuristics, JIT loader get_gemm_sm100_module_cutlass_bf16(), orchestrator bf16_gemm_sm100 integrating CUTLASS/TGV runners and autotuning, and package-level re-exports.
JIT Generation
flashinfer/jit/gemm/core.py, flashinfer/jit/gemm/__init__.py
Adds gen_gemm_sm100_module_cutlass_bf16() to render bf16/half sources from the jinja, emit generated sources with -DENABLE_BF16 flags, and export the generator.
Tests & Documentation
tests/gemm/test_mm_bf16.py, tests/gemm/test_bmm_bf16.py, docs/api/gemm.rst
Adds parameterized tests validating mm_bf16/bmm_bf16 (cosine similarity checks) and inserts BF16 GEMM entries into API docs.
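
The cosine-similarity validation mentioned for the tests can be sketched as follows; the threshold and the fp32 reference computation are illustrative assumptions, not the exact values used in the test files:

import torch

def cosine_sim_check(out: torch.Tensor, ref: torch.Tensor, threshold: float = 0.99) -> bool:
    # Compare the flattened results in fp32; a value near 1.0 means the
    # low-precision GEMM output tracks the reference closely.
    cos = torch.nn.functional.cosine_similarity(
        out.float().flatten(), ref.float().flatten(), dim=0
    )
    return cos.item() > threshold

# Example: ref = a.float() @ b.float(); assert cosine_sim_check(c, ref)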

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Py as Python caller (mm_bf16 / bmm_bf16)
    participant API as flashinfer.gemm API
    participant Orch as bf16_gemm_sm100 (orchestrator)
    participant JIT as JIT module / generator
    participant FFI as CUDA FFI (bf16_gemm)
    participant Runner as CutlassBf16GemmRunner
    participant Kernel as Cutlass kernel

    Py->>API: call mm_bf16/bmm_bf16(a,b,opts)
    API->>API: validate dtypes/shapes, prepare out & workspace
    API->>Orch: bf16_gemm_sm100(a,b,out,workspace,runner_names)
    Orch->>JIT: ensure/load SM100 BF16 module
    JIT-->>Orch: module + FFI bindings
    Orch->>FFI: call bf16_gemm(..., tactic)
    FFI->>Runner: getBf16GemmConfig(m,n,k,tactic)
    FFI->>Runner: runGemm<T>(...) -> compute workspace, call gemm
    Runner->>Kernel: launch Cutlass kernel on CUDA stream
    Kernel-->>Runner: kernel completes
    Runner-->>FFI: return status
    FFI-->>Orch: done
    Orch-->>API: completed
    API-->>Py: return tensor

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • djmmoss
  • cyx-6
  • yongwww
  • nvmbreughe
  • aleozlx
  • bkryu
  • jiahanc

Poem

🐇 I hopped through JIT and Cutlass light,
BF16 tiles aligned in tidy rows,
Tactics chosen, workspace snug and tight,
Kernels leapt—cuda streams in flight,
A rabbit cheers: math fast as it goes!

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 16.67%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: the title clearly describes the main change, adding BF16 GEMM support using the CUTLASS backend for SM100.
  • Description check ✅ Passed: the description addresses the template requirements; it explains the feature (BF16 GEMM implementation), references related issue #1974, and confirms that pre-commit checks and tests were completed.



@raayandhar
Contributor Author

raayandhar commented Nov 10, 2025

Currently there is an error about the second matrix being non-contiguous:
RuntimeError: Check failed: (mat2.IsContiguous()) is false: mat2 must be contiguous
I am trying to work on it. However, I have limited access to B200s, so it may take a while. I am also new to CUTLASS, so I would really appreciate any feedback from people with more experience here, especially concerning tile sizes: I am not sure what the best choices are, and some of them run into an error about SMEM space, which surprised me.
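
For anyone hitting the same error, a small sketch of the underlying PyTorch behavior (this only illustrates the semantics, not the fix eventually adopted in this PR):

import torch

b = torch.randn(256, 512, dtype=torch.bfloat16)

b_t = b.transpose(-2, -1)           # a view: strides are swapped, no data is copied
print(b_t.is_contiguous())          # False -> a contiguity check on mat2 rejects this view

b_col_major = b_t.contiguous()      # materializes a packed copy in the transposed layout
print(b_col_major.is_contiguous())  # True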

@raayandhar raayandhar force-pushed the user/rdhar/cutlass_bf16_gemm_sm100 branch from dd6216f to aaaee56 on November 10, 2025 04:09
@raayandhar raayandhar changed the title from "[FEAT] BF16 GEMM using CUTLASS backend for SM100" to "feat: BF16 GEMM using CUTLASS backend for SM100" on Nov 10, 2025
@raayandhar raayandhar marked this pull request as ready for review November 10, 2025 04:11
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 6

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f5a06a4 and 8ce4cb4.

📒 Files selected for processing (13)
  • csrc/bf16_gemm_cutlass.cu (1 hunks)
  • csrc/bf16_gemm_cutlass.jinja (1 hunks)
  • docs/api/gemm.rst (1 hunks)
  • flashinfer/__init__.py (1 hunks)
  • flashinfer/gemm/gemm_base.py (4 hunks)
  • flashinfer/jit/gemm/__init__.py (2 hunks)
  • flashinfer/jit/gemm/core.py (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h (0 hunks)
  • tests/gemm/test_bmm_bf16.py (1 hunks)
  • tests/gemm/test_mm_bf16.py (1 hunks)
💤 Files with no reviewable changes (1)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🧰 Additional context used
🧬 Code graph analysis (10)
flashinfer/jit/gemm/__init__.py (1)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-237)
include/flashinfer/gemm/bf16_gemm_cutlass.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-134)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-176)
flashinfer/gemm/gemm_base.py (1)
  • CutlassBf16GemmRunner (497-520)
flashinfer/gemm/gemm_base.py (5)
flashinfer/jit/gemm/core.py (2)
  • gen_gemm_sm100_module (240-316)
  • gen_gemm_sm100_module_cutlass_bf16 (193-237)
flashinfer/utils.py (3)
  • supported_compute_capability (773-853)
  • _get_cache_buf (205-211)
  • get_compute_capability (252-255)
flashinfer/autotuner.py (7)
  • TunableRunner (194-247)
  • OptimizationProfile (168-183)
  • AutoTuner (335-784)
  • TuningConfig (101-141)
  • DynamicTensorSpec (41-82)
  • ConstraintSpec (86-97)
  • choose_one (400-529)
csrc/bf16_gemm_cutlass.cu (4)
  • bf16_gemm_tactic_num (149-156)
  • bf16_gemm_tactic_num (149-149)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
flashinfer/fused_moe/utils.py (2)
  • get_last_power_of_2_num_tokens_buckets (206-215)
  • last_positive_power_of_2 (183-188)
tests/gemm/test_bmm_bf16.py (3)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm/gemm_base.py (1)
  • bmm_bf16 (250-313)
flashinfer/utils.py (2)
  • get_compute_capability (252-255)
  • is_compute_capability_supported (979-994)
flashinfer/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
  • bmm_bf16 (250-313)
  • mm_bf16 (183-246)
tests/gemm/test_mm_bf16.py (3)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm/gemm_base.py (1)
  • mm_bf16 (183-246)
flashinfer/utils.py (2)
  • get_compute_capability (252-255)
  • is_compute_capability_supported (979-994)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (4)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/fp8_gemm_cutlass_template.h (4)
  • flashinfer (41-145)
  • gemm (42-95)
  • std (184-184)
  • std (185-185)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (6)
  • gemm (44-176)
  • _1SM (53-57)
  • cutlass (135-135)
  • cutlass (136-136)
  • cutlass (137-137)
  • cutlass (138-138)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
flashinfer/compilation_context.py (1)
  • get_nvcc_flags_list (50-68)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-134)
  • gemm (42-91)
csrc/bf16_gemm_cutlass.cu (4)
flashinfer/gemm/gemm_base.py (1)
  • CutlassBf16GemmRunner (497-520)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • CutlassBf16GemmRunnerInterface (29-41)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
csrc/tvm_ffi_utils.h (2)
  • get_stream (272-274)
  • encode_dlpack_dtype (29-31)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_cutlass.h

[error] 20-20: 'cuda_runtime_api.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_cutlass_template.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_template_sm100.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

🪛 GitHub Actions: pre-commit
flashinfer/__init__.py

[error] 88-88: mypy: Module "flashinfer.gemm" has no attribute "bmm_bf16".


[error] 90-90: mypy: Module "flashinfer.gemm" has no attribute "mm_bf16".

🪛 Ruff (0.14.3)
flashinfer/gemm/gemm_base.py

218-218: Avoid specifying long messages outside the exception class

(TRY003)


220-220: Avoid specifying long messages outside the exception class

(TRY003)


230-232: Avoid specifying long messages outside the exception class

(TRY003)


234-236: Avoid specifying long messages outside the exception class

(TRY003)


238-240: Avoid specifying long messages outside the exception class

(TRY003)


284-284: Avoid specifying long messages outside the exception class

(TRY003)


286-286: Avoid specifying long messages outside the exception class

(TRY003)


297-299: Avoid specifying long messages outside the exception class

(TRY003)


301-303: Avoid specifying long messages outside the exception class

(TRY003)


305-307: Avoid specifying long messages outside the exception class

(TRY003)


500-500: Unused method argument: inputs

(ARG002)


501-501: Unused method argument: profile

(ARG002)


509-509: Unused method argument: do_preparation

(ARG002)


510-510: Unused method argument: kwargs

(ARG002)


592-592: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer/jit/gemm/core.py

233-233: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation

Replace with [*nvcc_flags, "-DENABLE_BF16"]

(RUF005)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
docs/api/gemm.rst (2)

10-17: Documentation formatting is consistent and well-structured.

The BF16 GEMM subsection follows the established pattern of other GEMM sections in the file (consistent indentation, autosummary directive, toctree configuration). Placement at the beginning of the GEMM API documentation is logical and appropriate.


10-17: Documentation is complete and accurate.

The BF16 GEMM subsection correctly documents mm_bf16 and bmm_bf16—these are the only public-facing BF16 GEMM functions (verified by top-level exports in flashinfer/__init__.py). The bf16_gemm mentioned in the PR summary is an internal C++ binding and tuning identifier, not a public Python API.

Comment thread csrc/bf16_gemm_cutlass.cu
Comment thread flashinfer/__init__.py
Comment thread flashinfer/gemm/gemm_base.py Outdated
Comment thread include/flashinfer/gemm/bf16_gemm_template_sm100.h Outdated
Comment thread tests/gemm/test_bmm_bf16.py
Comment thread tests/gemm/test_mm_bf16.py
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)

512-520: Materialize the transposed tensor before passing to the CUTLASS runner.

This is the root cause of the runtime error reported in the PR: b.transpose(-2, -1) returns a non-contiguous view, but the C++ binding requires contiguous input. The fix is to call .contiguous() on the transposed tensor.

Apply this fix:

                 a, b, out, workspace_buffer = inputs
                 module.bf16_gemm(
                     a,
-                    b.transpose(-2, -1),
+                    b.transpose(-2, -1).contiguous(),
                     out,
                     workspace_buffer,
                     tactic,
                 )

This issue was already identified in the previous review.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8ce4cb4 and 511d8e0.

📒 Files selected for processing (2)
  • flashinfer/gemm/__init__.py (2 hunks)
  • flashinfer/gemm/gemm_base.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/gemm/gemm_base.py (7)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-176)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-237)
flashinfer/utils.py (3)
  • supported_compute_capability (773-853)
  • _get_cache_buf (205-211)
  • get_compute_capability (252-255)
flashinfer/autotuner.py (7)
  • TunableRunner (194-247)
  • OptimizationProfile (168-183)
  • AutoTuner (335-784)
  • TuningConfig (101-141)
  • DynamicTensorSpec (41-82)
  • ConstraintSpec (86-97)
  • choose_one (400-529)
csrc/bf16_gemm_cutlass.cu (4)
  • bf16_gemm_tactic_num (149-156)
  • bf16_gemm_tactic_num (149-149)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
flashinfer/fused_moe/utils.py (2)
  • get_last_power_of_2_num_tokens_buckets (206-215)
  • last_positive_power_of_2 (183-188)
🪛 Ruff (0.14.3)
flashinfer/gemm/gemm_base.py

218-218: Avoid specifying long messages outside the exception class

(TRY003)


220-220: Avoid specifying long messages outside the exception class

(TRY003)


230-232: Avoid specifying long messages outside the exception class

(TRY003)


234-236: Avoid specifying long messages outside the exception class

(TRY003)


238-240: Avoid specifying long messages outside the exception class

(TRY003)


284-284: Avoid specifying long messages outside the exception class

(TRY003)


286-286: Avoid specifying long messages outside the exception class

(TRY003)


297-299: Avoid specifying long messages outside the exception class

(TRY003)


301-303: Avoid specifying long messages outside the exception class

(TRY003)


305-307: Avoid specifying long messages outside the exception class

(TRY003)


500-500: Unused method argument: inputs

(ARG002)


501-501: Unused method argument: profile

(ARG002)


509-509: Unused method argument: do_preparation

(ARG002)


510-510: Unused method argument: kwargs

(ARG002)


592-592: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (1)
flashinfer/gemm/__init__.py (1)

1-38: LGTM! Public API exports are correctly wired.

The new BF16 GEMM functions (bmm_bf16 and mm_bf16) are properly imported from gemm_base and exposed through the module's __all__ list, making them available as part of the public API.

Comment thread flashinfer/gemm/gemm_base.py Outdated
Comment thread flashinfer/gemm/gemm_base.py Outdated
Comment thread flashinfer/gemm/gemm_base.py Outdated
@raayandhar raayandhar changed the title from "feat: BF16 GEMM using CUTLASS backend for SM100" to "feat: (wip) BF16 GEMM using CUTLASS backend for SM100" on Nov 10, 2025
@raayandhar raayandhar force-pushed the user/rdhar/cutlass_bf16_gemm_sm100 branch from 511d8e0 to fbe5723 on November 12, 2025 01:54
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (5)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)

156-173: Restore workspace probe handling before launching kernels

CutlassBf16GemmRunner::getWorkspaceSize() calls this launcher with workspacePtr == nullptr to learn how big the buffer must be. Today we immediately throw because workspaceBytes == 0, so the probe reports 0 and the next real launch still fails with “insufficient workspace”. Please short‑circuit the probe and return the computed size instead of throwing.

   size_t workspace_size = gemm.get_workspace_size(arguments);
+  if (workspacePtr == nullptr) {
+    return workspace_size;
+  }
   if (workspace_size > workspaceBytes) {
flashinfer/gemm/gemm_base.py (4)

182-246: Validate BF16 MM inputs before dispatch

The CUTLASS path assumes 2‑D bf16 matrices on the same CUDA device with contiguous row‑major layout. Without the early guards we can accept the wrong dtype, mismatched shapes/devices, or a strided view and only fail deep inside the kernel (or produce garbage). Please restore the validation/contiguity fixes before touching the workspace.

-    if backend != "cutlass":
+    if backend != "cutlass":
         raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")
     if out_dtype not in (torch.bfloat16, torch.float16):
         raise ValueError("Only bf16 and fp16 outputs are supported.")
+
+    if a.ndim != 2 or b.ndim != 2:
+        raise ValueError(f"mm_bf16 expects 2D tensors. Got a.ndim={a.ndim}, b.ndim={b.ndim}.")
+    if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+        raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.")
+    if a.shape[1] != b.shape[0]:
+        raise ValueError(
+            f"Shape mismatch for matrix multiplication. a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}."
+        )
+    if a.device != b.device:
+        raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.")
+    if not a.is_contiguous():
+        a = a.contiguous()
+    if not b.is_contiguous():
+        b = b.contiguous()
+    if out is not None and not out.is_contiguous():
+        raise ValueError("Output tensor must be contiguous for the CUTLASS backend.")

249-313: Do the same validation for BMM

The batched entry point has the same holes: wrong dtype, rank, device, or non‑contiguous slices go straight into CUTLASS and fail later (or worse, corrupt results). Please add the missing checks for 3‑D tensors, matching batch/K dims, same device, and enforce contiguity before launching.

-    if backend != "cutlass":
+    if backend != "cutlass":
         raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")
     if out_dtype not in (torch.bfloat16, torch.float16):
         raise ValueError("Only bf16 and fp16 outputs are supported.")
+
+    if a.ndim != 3 or b.ndim != 3:
+        raise ValueError(f"bmm_bf16 expects 3D tensors. Got a.ndim={a.ndim}, b.ndim={b.ndim}.")
+    if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+        raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.")
+    if a.shape[0] != b.shape[0]:
+        raise ValueError(
+            f"Batch size mismatch. a.shape[0]={a.shape[0]} must equal b.shape[0]={b.shape[0]}."
+        )
+    if a.shape[2] != b.shape[1]:
+        raise ValueError(
+            f"K dimension mismatch. a.shape[2]={a.shape[2]} must equal b.shape[1]={b.shape[1]}."
+        )
+    if a.device != b.device:
+        raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.")
+    if not a.is_contiguous():
+        a = a.contiguous()
+    if not b.is_contiguous():
+        b = b.contiguous()
+    if out is not None and not out.is_contiguous():
+        raise ValueError("Output tensor must be contiguous for the CUTLASS backend.")

512-520: Materialize column‑major B before calling the kernel

bf16_gemm still receives b.transpose(-2, -1) directly, which is a non‑contiguous view and reproduces the runtime error (“mat2 must be contiguous”). Please allocate the column‑major buffer before dispatching to CUTLASS.

-                module.bf16_gemm(
-                    a,
-                    b.transpose(-2, -1),
-                    out,
-                    workspace_buffer,
-                    tactic,
-                )
+                b_col_major = b.transpose(-2, -1).contiguous()
+                module.bf16_gemm(
+                    a,
+                    b_col_major,
+                    out,
+                    workspace_buffer,
+                    tactic,
+                )

590-592: Report the actual device when no runner is found

When a lives on a non‑default GPU, torch.device("cuda") queries device 0 and we raise “sm100” even if the tensor was on sm90. Use the tensor’s device so the error reflects reality.

-        major, minor = get_compute_capability(torch.device("cuda"))
+        major, minor = get_compute_capability(a.device)
🧹 Nitpick comments (1)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)

54-91: Cluster shape dispatch with limited configuration support.

The function correctly dispatches based on cluster shape, with appropriate error handling for unsupported configurations. The limitation to only ClusterShape_1x1x1 aligns with the PR author's note about tile size and SMEM constraints during initial development.

Note: Line 66 has a break statement after return, which is unreachable but harmless.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 511d8e0 and fbe5723.

📒 Files selected for processing (14)
  • csrc/bf16_gemm_cutlass.cu (1 hunks)
  • csrc/bf16_gemm_cutlass.jinja (1 hunks)
  • docs/api/gemm.rst (1 hunks)
  • flashinfer/__init__.py (1 hunks)
  • flashinfer/gemm/__init__.py (2 hunks)
  • flashinfer/gemm/gemm_base.py (4 hunks)
  • flashinfer/jit/gemm/__init__.py (2 hunks)
  • flashinfer/jit/gemm/core.py (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h (0 hunks)
  • tests/gemm/test_bmm_bf16.py (1 hunks)
  • tests/gemm/test_mm_bf16.py (1 hunks)
💤 Files with no reviewable changes (1)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🚧 Files skipped from review as they are similar to previous changes (6)
  • tests/gemm/test_mm_bf16.py
  • flashinfer/gemm/__init__.py
  • csrc/bf16_gemm_cutlass.jinja
  • tests/gemm/test_bmm_bf16.py
  • flashinfer/jit/gemm/__init__.py
  • csrc/bf16_gemm_cutlass.cu
🧰 Additional context used
🧬 Code graph analysis (6)
flashinfer/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
  • bmm_bf16 (250-313)
  • mm_bf16 (183-246)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (6)
  • gemm (44-176)
  • _1SM (53-57)
  • cutlass (135-135)
  • cutlass (136-136)
  • cutlass (137-137)
  • cutlass (138-138)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
flashinfer/compilation_context.py (1)
  • get_nvcc_flags_list (50-68)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-134)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-176)
flashinfer/gemm/gemm_base.py (5)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-237)
flashinfer/utils.py (3)
  • supported_compute_capability (773-853)
  • _get_cache_buf (205-211)
  • get_compute_capability (252-255)
flashinfer/autotuner.py (5)
  • TunableRunner (194-247)
  • OptimizationProfile (168-183)
  • AutoTuner (335-784)
  • TuningConfig (101-141)
  • choose_one (400-529)
csrc/bf16_gemm_cutlass.cu (4)
  • bf16_gemm_tactic_num (149-156)
  • bf16_gemm_tactic_num (149-149)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
flashinfer/fused_moe/utils.py (2)
  • get_last_power_of_2_num_tokens_buckets (206-215)
  • last_positive_power_of_2 (183-188)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-134)
  • gemm (42-91)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_cutlass.h

[error] 20-20: 'cuda_runtime_api.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_template_sm100.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

🪛 Ruff (0.14.4)
flashinfer/jit/gemm/core.py

233-233: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation

Replace with [*nvcc_flags, "-DENABLE_BF16"]

(RUF005)

flashinfer/gemm/gemm_base.py

218-218: Avoid specifying long messages outside the exception class

(TRY003)


220-220: Avoid specifying long messages outside the exception class

(TRY003)


230-232: Avoid specifying long messages outside the exception class

(TRY003)


234-236: Avoid specifying long messages outside the exception class

(TRY003)


238-240: Avoid specifying long messages outside the exception class

(TRY003)


284-284: Avoid specifying long messages outside the exception class

(TRY003)


286-286: Avoid specifying long messages outside the exception class

(TRY003)


297-299: Avoid specifying long messages outside the exception class

(TRY003)


301-303: Avoid specifying long messages outside the exception class

(TRY003)


305-307: Avoid specifying long messages outside the exception class

(TRY003)


500-500: Unused method argument: inputs

(ARG002)


501-501: Unused method argument: profile

(ARG002)


509-509: Unused method argument: do_preparation

(ARG002)


510-510: Unused method argument: kwargs

(ARG002)


592-592: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (6)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)

29-41: Well-designed interface for BF16 GEMM runner.

The abstract interface provides a clean contract with appropriate virtual methods for GEMM operations, workspace management, and configuration enumeration. The virtual destructor is correctly included for safe polymorphic deletion.


43-57: Template class declaration follows proper separation pattern.

The template class declaration correctly inherits from the interface and overrides all pure virtual methods. The separation of declaration (here) and definition (in the template header) is appropriate for template code.

include/flashinfer/gemm/bf16_gemm_cutlass_template.h (4)

19-34: Appropriate diagnostic pragmas for CUTLASS integration.

The GCC diagnostic pragmas correctly suppress strict-aliasing warnings around CUTLASS headers, which is necessary since CUTLASS may use type punning internally.


136-143: GEMM implementation correctly delegates to dispatch logic.

The implementation properly forwards all parameters to dispatchToArch with appropriate type casting.


186-210: Configuration enumeration with limited initial support.

The function correctly enumerates candidate configurations by combining tile configs and cluster shapes. The current limitation to a single tile configuration (CtaShape64x64x128B) and cluster shape (ClusterShape_1x1x1) aligns with the PR objectives and the author's noted constraints regarding SMEM space and limited B200 hardware access for testing.

As additional tile sizes and cluster shapes are validated on SM100 hardware, uncomment the relevant lines to expand the configuration space.


99-103: Verify the intentional A↔B and m↔n parameter swap is correct for the kernel expectations.

The parameter swap pattern dispatchGemmClusterShapeSm100<T, arch, 64, 64, 128>(B, A, static_cast<T*>(D), n, m, k, ...) is applied consistently across all tile configurations in both bf16 and fp8 GEMM implementations. This is paired with explicit layout declarations: LayoutA = RowMajor and LayoutB = ColumnMajor.

While the consistency of this pattern across multiple files strongly suggests it is intentional for layout conversion, please confirm that this parameter reordering matches the actual kernel signature and expectations for dispatchGemmClusterShapeSm100.
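
One way to read the swap (an interpretation via the transpose identity, not something stated in the kernel source): a kernel that writes a column-major D is equivalent to one that writes a row-major Dᵀ, and since

    D = A · B   ⇔   Dᵀ = Bᵀ · Aᵀ

passing (B, A) with extents (n, m, k) lets the same launcher compute the transposed problem, which is consistent with the declared LayoutA = RowMajor and LayoutB = ColumnMajor.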

Comment thread include/flashinfer/gemm/bf16_gemm_cutlass_template.h
Comment thread include/flashinfer/gemm/bf16_gemm_cutlass_template.h
Comment thread include/flashinfer/gemm/bf16_gemm_cutlass_template.h
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (7)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)

166-173: Improve MNK hash mixing to avoid collisions.

XORing h1 ^ h2 ^ h3 collapses different (m,n,k) permutations to the same bucket, so cached workspace sizes can be reused for incompatible shapes. Combine the hashes with a proper mixer instead.

   struct MNKHash {
     size_t operator()(const MNK& mnk) const {
       auto h1 = std::hash<int>{}(std::get<0>(mnk));
       auto h2 = std::hash<int>{}(std::get<1>(mnk));
       auto h3 = std::hash<int>{}(std::get<2>(mnk));
-      return h1 ^ h2 ^ h3;
+      size_t seed = h1;
+      seed ^= h2 + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
+      seed ^= h3 + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
+      return seed;
     }
   };

175-183: Guard the static workspace cache with a mutex.

workspace_hashmap is mutated without synchronization; concurrent calls to getWorkspaceSize will race on find()/operator[]. Protect the cache with a lock.

-  static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
+  static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
+  static std::mutex workspace_mutex;
 
   size_t workspace_size = 0;
-  if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {
-    workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
-    workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size;
-  } else {
-    workspace_size = workspace_hashmap[std::make_tuple(m, n, k)];
-  }
+  const MNK key = std::make_tuple(m, n, k);
+  {
+    std::lock_guard<std::mutex> lock(workspace_mutex);
+    auto it = workspace_hashmap.find(key);
+    if (it != workspace_hashmap.end()) {
+      return it->second;
+    }
+    workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
+    workspace_hashmap.emplace(key, workspace_size);
+  }
   return workspace_size;
tests/gemm/test_mm_bf16.py (1)

14-21: Skip on CPU-only test environments.

get_compute_capability(torch.device("cuda")) raises when CUDA isn’t available, causing the entire suite to error out instead of skipping. Guard this with if not torch.cuda.is_available(): pytest.skip(...) before the capability query.
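
A minimal sketch of the guard (the helper name _require_cuda_device is hypothetical; the actual tests would inline this at the top of the test function):

import pytest
import torch

from flashinfer.utils import get_compute_capability

def _require_cuda_device():
    # Guard first: get_compute_capability(torch.device("cuda")) raises on
    # CPU-only machines, so check availability before touching the device.
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    return get_compute_capability(torch.device("cuda"))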

tests/gemm/test_bmm_bf16.py (1)

15-22: Gracefully skip when CUDA is unavailable.

Like the MM test, calling get_compute_capability(torch.device("cuda")) without checking torch.cuda.is_available() hard-fails on CPU-only setups. Add a skip guard before querying the device.

flashinfer/gemm/gemm_base.py (3)

217-240: Validate inputs before firing the kernel.

mm_bf16 still accepts tensors with wrong dtype, shape, or device, which the CUTLASS runner interprets incorrectly (e.g., passing fp16 data corrupts results). Add explicit checks for bf16 dtype, matching inner dimensions, and matching devices at the top of the function so misuse fails fast.

+    if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+        raise ValueError(
+            f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}."
+        )
+    if a.ndim != 2 or b.ndim != 2:
+        raise ValueError(
+            f"Inputs must be 2D matrices. Got a.ndim={a.ndim}, b.ndim={b.ndim}."
+        )
+    if a.shape[1] != b.shape[0]:
+        raise ValueError(
+            f"Shape mismatch: a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}."
+        )
+    if a.device != b.device:
+        raise ValueError(
+            f"Device mismatch: a.device={a.device}, b.device={b.device}."
+        )

288-307: Add basic sanity checks for batched inputs.

bmm_bf16 also needs dtype/shape/device validation; otherwise mismatched batch sizes or wrong K dimensions surface as low-level CUTLASS failures. Please mirror the checks from mm_bf16 for 3D tensors (batch, m, k) and (batch, k, n), ensuring matching batch/K dimensions and bf16 dtype.

+    if A.dtype != torch.bfloat16 or B.dtype != torch.bfloat16:
+        raise ValueError(
+            f"Inputs must be bfloat16. Got A.dtype={A.dtype}, B.dtype={B.dtype}."
+        )
+    if A.ndim != 3 or B.ndim != 3:
+        raise ValueError(
+            f"Inputs must be 3D tensors. Got A.ndim={A.ndim}, B.ndim={B.ndim}."
+        )
+    if A.shape[0] != B.shape[0]:
+        raise ValueError(
+            f"Batch mismatch: A.shape[0]={A.shape[0]} != B.shape[0]={B.shape[0]}."
+        )
+    if A.shape[2] != B.shape[1]:
+        raise ValueError(
+            f"K mismatch: A.shape[2]={A.shape[2]} must equal B.shape[1]={B.shape[1]}."
+        )
+    if A.device != B.device:
+        raise ValueError(
+            f"Device mismatch: A.device={A.device}, B.device={B.device}."
+        )

512-519: Make the transposed B operand contiguous.

The CUTLASS binding now enforces mat2.is_contiguous(). Transposing on the fly hands it a strided view and triggers the runtime error you reported. Materialize the column-major buffer before launching.

-                module.bf16_gemm(
-                    a,
-                    b.transpose(-2, -1),
-                    out,
-                    workspace_buffer,
-                    tactic,
-                )
+                b_col_major = b.transpose(-2, -1).contiguous()
+                module.bf16_gemm(
+                    a,
+                    b_col_major,
+                    out,
+                    workspace_buffer,
+                    tactic,
+                )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fbe5723 and 7f62bb0.

📒 Files selected for processing (4)
  • flashinfer/gemm/gemm_base.py (4 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
  • tests/gemm/test_bmm_bf16.py (1 hunks)
  • tests/gemm/test_mm_bf16.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.563Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.563Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.563Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h
  • flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (4)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (4)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/fp8_gemm_cutlass_template.h (4)
  • flashinfer (41-145)
  • gemm (42-95)
  • std (184-184)
  • std (185-185)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (6)
  • gemm (44-176)
  • _1SM (53-57)
  • cutlass (135-135)
  • cutlass (136-136)
  • cutlass (137-137)
  • cutlass (138-138)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
tests/gemm/test_mm_bf16.py (2)
flashinfer/gemm/gemm_base.py (1)
  • mm_bf16 (183-246)
flashinfer/utils.py (2)
  • get_compute_capability (252-255)
  • is_compute_capability_supported (979-994)
tests/gemm/test_bmm_bf16.py (3)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm/gemm_base.py (1)
  • bmm_bf16 (250-313)
flashinfer/utils.py (2)
  • get_compute_capability (252-255)
  • is_compute_capability_supported (979-994)
flashinfer/gemm/gemm_base.py (8)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-176)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-237)
flashinfer/utils.py (2)
  • supported_compute_capability (773-853)
  • _get_cache_buf (205-211)
flashinfer/autotuner.py (7)
  • TunableRunner (194-247)
  • OptimizationProfile (168-183)
  • AutoTuner (335-784)
  • TuningConfig (101-141)
  • DynamicTensorSpec (41-82)
  • ConstraintSpec (86-97)
  • choose_one (400-529)
csrc/bf16_gemm_cutlass.cu (4)
  • bf16_gemm_tactic_num (149-156)
  • bf16_gemm_tactic_num (149-149)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
flashinfer/fused_moe/utils.py (2)
  • get_last_power_of_2_num_tokens_buckets (206-215)
  • last_positive_power_of_2 (183-188)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

🪛 Ruff (0.14.4)
flashinfer/gemm/gemm_base.py

218-218: Avoid specifying long messages outside the exception class

(TRY003)


220-220: Avoid specifying long messages outside the exception class

(TRY003)


230-232: Avoid specifying long messages outside the exception class

(TRY003)


234-236: Avoid specifying long messages outside the exception class

(TRY003)


238-240: Avoid specifying long messages outside the exception class

(TRY003)


284-284: Avoid specifying long messages outside the exception class

(TRY003)


286-286: Avoid specifying long messages outside the exception class

(TRY003)


297-299: Avoid specifying long messages outside the exception class

(TRY003)


301-303: Avoid specifying long messages outside the exception class

(TRY003)


305-307: Avoid specifying long messages outside the exception class

(TRY003)


500-500: Unused method argument: inputs

(ARG002)


501-501: Unused method argument: profile

(ARG002)


509-509: Unused method argument: do_preparation

(ARG002)


510-510: Unused method argument: kwargs

(ARG002)

Comment thread tests/gemm/test_bmm_bf16.py
Comment thread tests/gemm/test_mm_bf16.py
@raayandhar raayandhar force-pushed the user/rdhar/cutlass_bf16_gemm_sm100 branch from d2c8547 to 8a58e45 on November 16, 2025 22:36
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (7)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)

156-170: Allow workspace probes to succeed when no buffer is provided.

CutlassBf16GemmRunner::getWorkspaceSizeImpl invokes this launcher with workspacePtr == nullptr and workspaceBytes == 0 to query the required size. The current code throws before returning the computed workspace_size, breaking workspace queries. Short-circuit when workspacePtr is nullptr to return the size without running the kernel.

   size_t workspace_size = gemm.get_workspace_size(arguments);
+  if (workspacePtr == nullptr) {
+    return workspace_size;
+  }
   if (workspace_size > workspaceBytes) {
     throw std::runtime_error("[Bf16 Gemm Runner] insufficient workspace");
   }
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)

166-173: Hash function prone to collisions.

The MNKHash function uses XOR to combine hash values (h1 ^ h2 ^ h3), which produces collisions for permutations of the same values. For example, (1, 2, 3) and (3, 2, 1) hash identically, potentially returning incorrect workspace sizes.

Use a proper hash combining algorithm:

   struct MNKHash {
     size_t operator()(const MNK& mnk) const {
       auto h1 = std::hash<int>{}(std::get<0>(mnk));
       auto h2 = std::hash<int>{}(std::get<1>(mnk));
       auto h3 = std::hash<int>{}(std::get<2>(mnk));
-      return h1 ^ h2 ^ h3;
+      // Combine hashes properly to avoid collisions
+      size_t seed = h1;
+      seed ^= h2 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+      seed ^= h3 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+      return seed;
     }
   };

175-184: Critical: Data race on static workspace cache.

The static workspace_hashmap at Line 175 is accessed concurrently without synchronization. While C++11+ guarantees thread-safe initialization of function-local statics, concurrent access via find() (Line 178) and operator[] (Lines 180, 182) creates data races if getWorkspaceSize is called from multiple threads.

Protect the map with a mutex:

+  static std::mutex workspace_mutex;
   static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;

   size_t workspace_size = 0;
+  std::lock_guard<std::mutex> lock(workspace_mutex);
   if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {

Alternatively, use std::shared_mutex with shared (read) and exclusive (write) locking for better concurrent read performance.

tests/gemm/test_bmm_bf16.py (1)

15-22: Guard test behind CUDA availability.

Calling get_compute_capability(torch.device("cuda")) without first checking torch.cuda.is_available() will raise an exception on non-CUDA systems instead of skipping gracefully.

Add an early CUDA check:

 def test_bmm_bf16(b, m, n, k, res_dtype):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
     compute_capability = get_compute_capability(torch.device(device="cuda"))
flashinfer/gemm/gemm_base.py (3)

182-245: Add input validation for dtype, shape, and device consistency.

The function is missing essential input validation that could lead to cryptic errors downstream. Per previous review feedback, please add checks at the beginning of the function:

+    if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+        raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.")
+    if a.shape[1] != b.shape[0]:
+        raise ValueError(
+            f"Shape mismatch for matrix multiplication. "
+            f"a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}."
+        )
+    if a.device != b.device:
+        raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.")
+
     if backend != "cutlass":

248-312: Add input validation for dtype, shape, and device consistency.

Similar to mm_bf16, this function lacks essential input validation. Per previous review feedback, please add checks at the beginning:

+    if A.dtype != torch.bfloat16 or B.dtype != torch.bfloat16:
+        raise ValueError(f"Inputs must be bfloat16. Got A.dtype={A.dtype}, B.dtype={B.dtype}.")
+    if A.ndim != 3 or B.ndim != 3:
+        raise ValueError(f"Expected 3D tensors. Got A.ndim={A.ndim}, B.ndim={B.ndim}.")
+    if A.shape[0] != B.shape[0]:
+        raise ValueError(
+            f"Batch size mismatch. A.shape[0]={A.shape[0]} != B.shape[0]={B.shape[0]}."
+        )
+    if A.shape[2] != B.shape[1]:
+        raise ValueError(
+            f"Shape mismatch for batched matrix multiplication. "
+            f"A.shape[2]={A.shape[2]} must equal B.shape[1]={B.shape[1]}."
+        )
+    if A.device != B.device:
+        raise ValueError(f"Device mismatch. A.device={A.device}, B.device={B.device}.")
+
     if backend != "cutlass":

511-519: Make the B operand contiguous before invoking the CUTLASS runner.

transpose(-2, -1) returns a non-contiguous view, which causes the runtime error you reported: "RuntimeError: Check failed: (mat2.IsContiguous()) is false: mat2 must be contiguous". Per previous review feedback, materialize the column-major buffer before launching the kernel:

+                b_col_major = b.transpose(-2, -1).contiguous()
                 module.bf16_gemm(
                     a,
-                    b.transpose(-2, -1),
+                    b_col_major,
                     out,
                     workspace_buffer,
                     tactic,
                 )
🧹 Nitpick comments (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)

187-211: Config enumeration is clean but limited to one configuration.

The getConfigs implementation only enumerates CtaShape64x64x128B and ClusterShape_1x1x1, reflecting the WIP status and SMEM constraints. The nested loop pattern is extensible for adding more configs once SMEM issues are resolved.

Would you like help generating a script to analyze SMEM usage across different tile configurations to understand which sizes are viable for SM100?

flashinfer/jit/gemm/core.py (1)

193-237: WIP tile configurations are appropriate for initial testing.

The implementation correctly follows the established pattern from FP8/FP4 modules. The single active tile configuration (64, 64, 128) is a reasonable conservative choice while debugging SMEM constraints on SM100 hardware, especially given your limited B200 access.

Optional style improvement (flagged by static analysis):

-    return gen_jit_spec(
-        "bf16_gemm_cutlass",
-        source_paths,
-        extra_cuda_cflags=nvcc_flags + ["-DENABLE_BF16"],
-        extra_cflags=[
-            "-DFAST_BUILD",
-        ],
-    )
+    return gen_jit_spec(
+        "bf16_gemm_cutlass",
+        source_paths,
+        extra_cuda_cflags=[*nvcc_flags, "-DENABLE_BF16"],
+        extra_cflags=[
+            "-DFAST_BUILD",
+        ],
+    )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d2c8547 and 8a58e45.

📒 Files selected for processing (14)
  • csrc/bf16_gemm_cutlass.cu (1 hunks)
  • csrc/bf16_gemm_cutlass.jinja (1 hunks)
  • docs/api/gemm.rst (1 hunks)
  • flashinfer/__init__.py (1 hunks)
  • flashinfer/gemm/__init__.py (2 hunks)
  • flashinfer/gemm/gemm_base.py (4 hunks)
  • flashinfer/jit/gemm/__init__.py (2 hunks)
  • flashinfer/jit/gemm/core.py (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h (0 hunks)
  • tests/gemm/test_bmm_bf16.py (1 hunks)
  • tests/gemm/test_mm_bf16.py (1 hunks)
💤 Files with no reviewable changes (1)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🚧 Files skipped from review as they are similar to previous changes (2)
  • csrc/bf16_gemm_cutlass.cu
  • tests/gemm/test_mm_bf16.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/bf16_gemm_cutlass.jinja
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h
  • flashinfer/gemm/gemm_base.py
  • include/flashinfer/gemm/bf16_gemm_cutlass.h
  • flashinfer/__init__.py
🧬 Code graph analysis (9)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-134)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (6)
  • gemm (44-176)
  • _1SM (53-57)
  • cutlass (135-135)
  • cutlass (136-136)
  • cutlass (137-137)
  • cutlass (138-138)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
flashinfer/gemm/gemm_base.py (8)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-176)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-237)
flashinfer/utils.py (2)
  • supported_compute_capability (773-853)
  • _get_cache_buf (205-211)
flashinfer/autotuner.py (7)
  • TunableRunner (194-247)
  • OptimizationProfile (168-183)
  • AutoTuner (335-786)
  • TuningConfig (101-141)
  • DynamicTensorSpec (41-82)
  • ConstraintSpec (86-97)
  • choose_one (400-529)
csrc/bf16_gemm_cutlass.cu (4)
  • bf16_gemm_tactic_num (149-156)
  • bf16_gemm_tactic_num (149-149)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
flashinfer/fused_moe/utils.py (2)
  • get_last_power_of_2_num_tokens_buckets (206-215)
  • last_positive_power_of_2 (183-188)
include/flashinfer/gemm/bf16_gemm_cutlass.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-134)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-176)
flashinfer/gemm/gemm_base.py (1)
  • CutlassBf16GemmRunner (496-519)
flashinfer/gemm/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
  • bmm_bf16 (249-312)
  • mm_bf16 (183-245)
flashinfer/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
  • bmm_bf16 (249-312)
  • mm_bf16 (183-245)
flashinfer/jit/gemm/__init__.py (1)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-237)
tests/gemm/test_bmm_bf16.py (2)
flashinfer/gemm/gemm_base.py (1)
  • bmm_bf16 (249-312)
flashinfer/utils.py (2)
  • get_compute_capability (252-255)
  • is_compute_capability_supported (979-994)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
flashinfer/compilation_context.py (1)
  • get_nvcc_flags_list (50-68)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_template_sm100.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_cutlass_template.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_cutlass.h

[error] 20-20: 'cuda_runtime_api.h' file not found

(clang-diagnostic-error)

🪛 Ruff (0.14.4)
flashinfer/gemm/gemm_base.py

217-217: Avoid specifying long messages outside the exception class

(TRY003)


219-219: Avoid specifying long messages outside the exception class

(TRY003)


229-231: Avoid specifying long messages outside the exception class

(TRY003)


233-235: Avoid specifying long messages outside the exception class

(TRY003)


237-239: Avoid specifying long messages outside the exception class

(TRY003)


283-283: Avoid specifying long messages outside the exception class

(TRY003)


285-285: Avoid specifying long messages outside the exception class

(TRY003)


296-298: Avoid specifying long messages outside the exception class

(TRY003)


300-302: Avoid specifying long messages outside the exception class

(TRY003)


304-306: Avoid specifying long messages outside the exception class

(TRY003)


499-499: Unused method argument: inputs

(ARG002)


500-500: Unused method argument: profile

(ARG002)


508-508: Unused method argument: do_preparation

(ARG002)


509-509: Unused method argument: kwargs

(ARG002)

flashinfer/jit/gemm/core.py

233-233: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation

Replace with [*nvcc_flags, "-DENABLE_BF16"]

(RUF005)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (12)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)

1-62: LGTM! Clean interface/implementation pattern.

The abstract interface and templated concrete class follow best practices for extensibility. The separation of public getWorkspaceSize and private getWorkspaceSizeImpl suggests proper encapsulation of workspace size computation logic.

include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)

46-151: LGTM! Standard CUTLASS GEMM setup.

The SMTypeAdapter specializations and launcher configuration follow CUTLASS patterns correctly. Regarding the comment on Line 147: setting fusion_args.alpha = 1.0f and fusion_args.beta = 0.0f is the standard way to configure a GEMM epilogue for D = A*B (no accumulation). This is the right approach.
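For intuition, here is a small host-side sketch (plain C++, not the CUTLASS epilogue itself) of the D = alpha * (A*B) + beta * C contract, showing that alpha = 1 and beta = 0 reduces it to a plain product with no accumulation into C:

#include <cstdio>
#include <vector>

// Naive reference for D = alpha * (A * B) + beta * C, row-major.
// With alpha = 1 and beta = 0 this is exactly D = A * B, i.e. no accumulation
// into C -- the same semantics described for the CUTLASS epilogue above.
void gemm_ref(int m, int n, int k, float alpha, float beta,
              const std::vector<float>& A, const std::vector<float>& B,
              const std::vector<float>& C, std::vector<float>& D) {
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      D[i * n + j] = alpha * acc + beta * C[i * n + j];
    }
}

int main() {
  int m = 2, n = 2, k = 2;
  std::vector<float> A{1, 2, 3, 4}, B{5, 6, 7, 8}, C(4, 100.f), D(4);
  gemm_ref(m, n, k, /*alpha=*/1.f, /*beta=*/0.f, A, B, C, D);
  std::printf("%g %g %g %g\n", D[0], D[1], D[2], D[3]);  // 19 22 43 50, C ignored
  return 0;
}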

include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)

136-143: LGTM! Clean forwarding to dispatcher.


145-160: Exception handling is appropriate for config probing.

The pattern of catching and ignoring std::runtime_error when probing workspace sizes is acceptable, as some configurations may legitimately fail due to SMEM constraints. The comment on Line 155 documents the rationale clearly.

Based on learnings
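As a standalone illustration of the probe-and-swallow pattern (probeWorkspaceSize is a hypothetical stand-in for querying one configuration, not a real FlashInfer or CUTLASS function):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>

// Hypothetical stand-in for asking one GEMM configuration how much workspace
// it needs; the real code may throw when a tile does not fit in shared memory.
size_t probeWorkspaceSize(int config_id, int m, int n, int k) {
  if (config_id == 2) throw std::runtime_error("SMEM exceeds maximum allowed");
  return static_cast<size_t>(m) * n + static_cast<size_t>(k) * (config_id + 1);
}

size_t maxWorkspaceOverConfigs(int m, int n, int k) {
  size_t max_ws = 0;
  for (int config_id = 0; config_id < 4; ++config_id) {
    try {
      max_ws = std::max(max_ws, probeWorkspaceSize(config_id, m, n, k));
    } catch (const std::runtime_error&) {
      // Swallow errors when SMEM exceeds maximum allowed: this config is
      // simply not usable for these shapes, but others may be.
    }
  }
  return max_ws;
}

int main() {
  std::cout << maxWorkspaceOverConfigs(64, 64, 128) << "\n";  // largest viable config's size
  return 0;
}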


44-134: No changes needed—review comment is accurate.

Verification confirms the bf16 dispatcher is intentionally limited to CtaShape64x64x128B and ClusterShape_1x1x1 (lines 100–103 and 65–67), while other tile configs and cluster shapes remain commented out. This differs from the fp8 implementation, which enables multiple configurations, confirming the bf16 limitation is deliberate due to SMEM constraints as noted. The transpose pattern (swapping B, A and n, m at line 101–102) is correct for layout handling.

docs/api/gemm.rst (1)

10-18: LGTM! Documentation follows existing patterns.

The new BF16 GEMM section properly documents the mm_bf16 and bmm_bf16 entry points, following the same autosummary format as other GEMM types in this file.

flashinfer/__init__.py (1)

88-90: LGTM! BF16 GEMM exports are now available.

The imports of bmm_bf16 and mm_bf16 from the gemm module expose the new BF16 GEMM functionality at the top level. Past review comments indicate the necessary exports were added to flashinfer/gemm/__init__.py.

flashinfer/jit/gemm/__init__.py (1)

22-22: LGTM! JIT generator export follows existing patterns.

The gen_gemm_sm100_module_cutlass_bf16 import and export are consistent with other GEMM generators in this module.

Also applies to: 37-37

flashinfer/gemm/__init__.py (1)

2-2: LGTM! GEMM module exports are properly configured.

The bmm_bf16 and mm_bf16 imports from gemm_base and their inclusion in __all__ enable the top-level imports in flashinfer/__init__.py to work correctly.

Also applies to: 4-4, 25-25, 27-27

tests/gemm/test_bmm_bf16.py (1)

23-34: LGTM! Test logic is sound.

The test correctly creates BF16 inputs, computes a reference with torch.bmm, and validates the bmm_bf16 output using cosine similarity. The threshold of 0.99 is reasonable for BF16 precision.

csrc/bf16_gemm_cutlass.jinja (1)

1-27: LGTM! Clean template structure with conservative defaults.

The template correctly instantiates the SM100 BF16 GEMM kernel with a single-SM cluster configuration (1,1,1), which is appropriate for initial testing. The commented-out multi-SM cluster configurations provide clear guidance for future performance tuning once the basic implementation is validated.

flashinfer/gemm/gemm_base.py (1)

577-616: LGTM! AutoTuner integration follows established patterns.

The function correctly uses a.device for SM version checking and properly integrates with the AutoTuner for dynamic tactic selection. The tuning configuration appropriately profiles on the M dimension using power-of-2 bucketing, matching the pattern used in fp8_gemm_sm100.
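For reference, the bucketing idea itself is simple. The real helpers are the Python functions last_positive_power_of_2 / get_last_power_of_2_num_tokens_buckets; a C++ sketch of rounding M down to a power of two looks like:

#include <iostream>

// Round m down to the previous power of two -- the bucketing used when
// profiling on the M dimension (illustrative only; the actual helpers live in
// flashinfer/fused_moe/utils.py).
int last_power_of_2(int m) {
  int p = 1;
  while (p * 2 <= m) p *= 2;
  return p;
}

int main() {
  for (int m : {1, 3, 7, 13, 64, 100}) {
    std::cout << "m=" << m << " -> bucket " << last_power_of_2(m) << "\n";
  }
  return 0;
}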

@raayandhar raayandhar force-pushed the user/rdhar/cutlass_bf16_gemm_sm100 branch from dcbc17a to 28baee5 on November 17, 2025 05:10
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (7)
tests/gemm/test_mm_bf16.py (2)

25-31: Use row‑major (k, n) weight and avoid .T to keep inputs contiguous.

Generate mat2 as (k, n), use it directly in both the reference and API call. This prevents passing a non‑contiguous transpose and matches the documented contract.

-    mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
-
-    reference = torch.mm(input, mat2.T)
+    mat2 = torch.randn([k, n], device="cuda", dtype=torch.bfloat16)
+    reference = torch.mm(input, mat2)
 ...
-        mm_bf16(input, mat2.T, out=out, out_dtype=res_dtype)
+        mm_bf16(input, mat2, out=out, out_dtype=res_dtype)

14-16: Skip on CPU-only to avoid hard failure.

Add CUDA-availability guard before calling get_compute_capability.

 def test_mm_bf16(m: int, n: int, k: int, res_dtype: torch.dtype):
-    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
tests/gemm/test_bmm_bf16.py (1)

14-16: Skip on CPU-only to avoid hard failure.

Guard get_compute_capability(torch.device("cuda")) with a CUDA-availability check so the test skips instead of crashing on CPU-only runners.

 def test_bmm_bf16(b, m, n, k, res_dtype):
-    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
flashinfer/gemm/gemm_base.py (3)

182-205: Add essential input validation (dtype, shape, device).

Fail fast with clear errors to avoid cryptic backend failures.

 def mm_bf16(
     a: torch.Tensor,
     b: torch.Tensor,
@@
 ) -> torch.Tensor:
@@
-    if backend != "cutlass":
+    # Basic validations
+    if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+        raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.")
+    if a.ndim != 2 or b.ndim != 2:
+        raise ValueError(f"mm_bf16 expects 2D tensors. Got a.ndim={a.ndim}, b.ndim={b.ndim}.")
+    if a.shape[1] != b.shape[0]:
+        raise ValueError(
+            f"Shape mismatch: a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}."
+        )
+    if a.device != b.device:
+        raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.")
+
+    if backend != "cutlass":
         raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")

259-283: Add essential input validation (batched dtype/shape/device).

Validate 3D inputs, batch, and K dims before launching the kernel.

 def bmm_bf16(
     A: torch.Tensor,
     B: torch.Tensor,
@@
 ) -> torch.Tensor:
@@
-    if backend != "cutlass":
+    # Basic validations
+    if A.dtype != torch.bfloat16 or B.dtype != torch.bfloat16:
+        raise ValueError(f"Inputs must be bfloat16. Got A.dtype={A.dtype}, B.dtype={B.dtype}.")
+    if A.ndim != 3 or B.ndim != 3:
+        raise ValueError(f"bmm_bf16 expects 3D tensors. Got A.ndim={A.ndim}, B.ndim={B.ndim}.")
+    if A.shape[0] != B.shape[0]:
+        raise ValueError(f"Batch size mismatch: A.shape[0]={A.shape[0]} != B.shape[0]={B.shape[0]}.")
+    if A.shape[2] != B.shape[1]:
+        raise ValueError(
+            f"K mismatch: A.shape[2]={A.shape[2]} must equal B.shape[1]={B.shape[1]}."
+        )
+    if A.device != B.device:
+        raise ValueError(f"Device mismatch. A.device={A.device}, B.device={B.device}.")
+
+    if backend != "cutlass":
         raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")

533-541: Fix runtime error: make B contiguous before calling the CUTLASS binding.

b.transpose(-2, -1) is a non‑contiguous view; the C++ binding asserts contiguity (“mat2 must be contiguous”). Materialize column‑major B.

-                module.bf16_gemm(
-                    a,
-                    b.transpose(-2, -1),
-                    out,
-                    workspace_buffer,
-                    tactic,
-                )
+                b_col_major = b.transpose(-2, -1).contiguous()
+                module.bf16_gemm(
+                    a,
+                    b_col_major,
+                    out,
+                    workspace_buffer,
+                    tactic,
+                )
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)

157-175: Thread-safety and hash quality in workspace cache.

  • XOR-combining hashes collides easily.
  • The static workspace_hashmap is accessed unsafely across threads.

Harden both.

@@
-  struct MNKHash {
+  struct MNKHash {
     size_t operator()(const MNK& mnk) const {
       auto h1 = std::hash<int>{}(std::get<0>(mnk));
       auto h2 = std::hash<int>{}(std::get<1>(mnk));
       auto h3 = std::hash<int>{}(std::get<2>(mnk));
-      return h1 ^ h2 ^ h3;
+      // Robust hash combine to reduce collisions
+      size_t seed = h1;
+      seed ^= h2 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+      seed ^= h3 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+      return seed;
     }
   };
@@
-  static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
+  static std::mutex workspace_mutex;
+  static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
@@
-  size_t workspace_size = 0;
-  if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {
-    workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
-    workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size;
-  } else {
-    workspace_size = workspace_hashmap[std::make_tuple(m, n, k)];
-  }
-  return workspace_size;
+  const MNK key = std::make_tuple(m, n, k);
+  {
+    std::lock_guard<std::mutex> lock(workspace_mutex);
+    auto it = workspace_hashmap.find(key);
+    if (it != workspace_hashmap.end()) {
+      return it->second;
+    }
+  }
+  // Compute outside lock to avoid blocking others; insert with lock.
+  size_t computed = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
+  {
+    std::lock_guard<std::mutex> lock(workspace_mutex);
+    auto it = workspace_hashmap.find(key);
+    if (it == workspace_hashmap.end()) {
+      workspace_hashmap.emplace(key, computed);
+      return computed;
+    }
+    return it->second;
+  }

Also add the include near the top:

-#include <stdexcept>
+#include <stdexcept>
+#include <mutex>
🧹 Nitpick comments (2)
flashinfer/jit/gemm/core.py (1)

229-236: Minor: prefer list splat over concatenation (RUF005).

Use list unpacking for readability in extra_cuda_cflags.

-    return gen_jit_spec(
+    return gen_jit_spec(
         "bf16_gemm_cutlass",
         source_paths,
-        extra_cuda_cflags=nvcc_flags + ["-DENABLE_BF16"],
+        extra_cuda_cflags=[*nvcc_flags, "-DENABLE_BF16"],
         extra_cflags=[
             "-DFAST_BUILD",
         ],
     )
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)

101-104: Remove unused template parameter or use it.

genericBf16GemmKernelLauncherSm100 has template param arch but hardcodes ArchTag = cutlass::arch::Sm100. Either use arch or drop the parameter.

-  using ArchTag = cutlass::arch::Sm100;
+  using ArchTag = arch;
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8a58e45 and 28baee5.

📒 Files selected for processing (14)
  • csrc/bf16_gemm_cutlass.cu (1 hunks)
  • csrc/bf16_gemm_cutlass.jinja (1 hunks)
  • docs/api/gemm.rst (1 hunks)
  • flashinfer/__init__.py (1 hunks)
  • flashinfer/gemm/__init__.py (2 hunks)
  • flashinfer/gemm/gemm_base.py (4 hunks)
  • flashinfer/jit/gemm/__init__.py (2 hunks)
  • flashinfer/jit/gemm/core.py (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h (0 hunks)
  • tests/gemm/test_bmm_bf16.py (1 hunks)
  • tests/gemm/test_mm_bf16.py (1 hunks)
💤 Files with no reviewable changes (1)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🚧 Files skipped from review as they are similar to previous changes (5)
  • docs/api/gemm.rst
  • csrc/bf16_gemm_cutlass.cu
  • flashinfer/__init__.py
  • flashinfer/gemm/__init__.py
  • flashinfer/jit/gemm/__init__.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/bf16_gemm_cutlass.jinja
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h
  • include/flashinfer/gemm/bf16_gemm_cutlass.h
  • flashinfer/gemm/gemm_base.py
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h
🧬 Code graph analysis (7)
tests/gemm/test_mm_bf16.py (3)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm/gemm_base.py (1)
  • mm_bf16 (183-256)
flashinfer/utils.py (2)
  • get_compute_capability (252-255)
  • is_compute_capability_supported (979-994)
tests/gemm/test_bmm_bf16.py (3)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm/gemm_base.py (1)
  • bmm_bf16 (260-334)
flashinfer/utils.py (2)
  • get_compute_capability (252-255)
  • is_compute_capability_supported (979-994)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
flashinfer/compilation_context.py (1)
  • get_nvcc_flags_list (50-68)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-125)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_cutlass.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-125)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-182)
flashinfer/gemm/gemm_base.py (1)
  • CutlassBf16GemmRunner (518-541)
flashinfer/gemm/gemm_base.py (8)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-182)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • gemm (27-59)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-236)
flashinfer/utils.py (2)
  • supported_compute_capability (773-853)
  • _get_cache_buf (205-211)
flashinfer/autotuner.py (5)
  • TunableRunner (194-247)
  • OptimizationProfile (168-183)
  • TuningConfig (101-141)
  • DynamicTensorSpec (41-82)
  • ConstraintSpec (86-97)
csrc/bf16_gemm_cutlass.cu (4)
  • bf16_gemm_tactic_num (149-156)
  • bf16_gemm_tactic_num (149-149)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
flashinfer/fused_moe/utils.py (2)
  • get_last_power_of_2_num_tokens_buckets (206-215)
  • last_positive_power_of_2 (183-188)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (7)
  • gemm (44-182)
  • _1SM (53-57)
  • _2SM (60-64)
  • cutlass (135-135)
  • cutlass (136-136)
  • cutlass (137-137)
  • cutlass (138-138)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_template_sm100.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_cutlass.h

[error] 20-20: 'cuda_runtime_api.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_cutlass_template.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

🪛 Ruff (0.14.4)
flashinfer/jit/gemm/core.py

232-232: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation

Replace with [*nvcc_flags, "-DENABLE_BF16"]

(RUF005)

flashinfer/gemm/gemm_base.py

228-228: Avoid specifying long messages outside the exception class

(TRY003)


230-230: Avoid specifying long messages outside the exception class

(TRY003)


240-242: Avoid specifying long messages outside the exception class

(TRY003)


244-246: Avoid specifying long messages outside the exception class

(TRY003)


248-250: Avoid specifying long messages outside the exception class

(TRY003)


305-305: Avoid specifying long messages outside the exception class

(TRY003)


307-307: Avoid specifying long messages outside the exception class

(TRY003)


318-320: Avoid specifying long messages outside the exception class

(TRY003)


322-324: Avoid specifying long messages outside the exception class

(TRY003)


326-328: Avoid specifying long messages outside the exception class

(TRY003)


521-521: Unused method argument: inputs

(ARG002)


522-522: Unused method argument: profile

(ARG002)


530-530: Unused method argument: do_preparation

(ARG002)


531-531: Unused method argument: kwargs

(ARG002)

🔇 Additional comments (2)
csrc/bf16_gemm_cutlass.jinja (1)

17-26: Instantiation set looks good.

Coverage of cluster shapes for 1SM/2SM variants matches the SM100 launcher; no issues spotted.

include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)

152-157: Good: explicit workspace probe path.

Early-return on null A/B/D ensures getWorkspaceSizeImpl can probe without needing a buffer. This unblocks tactic sizing. Based on learnings.
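A simplified sketch of that probe convention (names are illustrative, not the real launcher signature):

#include <cstddef>
#include <iostream>

// When the caller passes null data pointers, the launcher only reports how
// much workspace the selected configuration needs and returns without
// launching anything.
size_t launchOrProbe(const void* A, const void* B, void* D, int m, int n, int k,
                     void* workspace, size_t workspace_bytes) {
  const size_t required = static_cast<size_t>(m) * n * sizeof(float);  // placeholder sizing
  if (A == nullptr && B == nullptr && D == nullptr) {
    return required;  // pure workspace query, no buffers needed
  }
  // Real path: check workspace_bytes >= required, then launch using `workspace`.
  (void)workspace;
  (void)workspace_bytes;
  (void)k;
  return required;
}

int main() {
  std::cout << "workspace bytes: "
            << launchOrProbe(nullptr, nullptr, nullptr, 128, 256, 512, nullptr, 0) << "\n";
  return 0;
}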

@raayandhar raayandhar changed the title feat: (wip) BF16 GEMM using CUTLASS backend for SM100 feat: BF16 GEMM using CUTLASS backend for SM100 Nov 17, 2025
@raayandhar
Contributor Author

raayandhar commented Nov 17, 2025

Hi experts, I think this is now ready for review!
I had more trouble than I expected even though there were already FP8 and FP4 CUTLASS GEMM implementations, and I learned a lot working on this, especially since it was my first time working with CUTLASS.

Right now we are passing all the tests that I wrote for this feature:

Test Results
(flashinfer) root@spry-shaggy-smilodon:~/flashinfer# pytest tests/gemm/test_bmm_bf16.py
======================================================== test session starts =========================================================
platform linux -- Python 3.12.12, pytest-9.0.1, pluggy-1.6.0
rootdir: /root/flashinfer
configfile: pytest.ini
collected 32 items                                                                                                                   

tests/gemm/test_bmm_bf16.py ................................                                                                   [100%]

========================================================= 32 passed in 2.37s =========================================================
(flashinfer) root@spry-shaggy-smilodon:~/flashinfer# pytest tests/gemm/test_mm_bf16.py
======================================================== test session starts =========================================================
platform linux -- Python 3.12.12, pytest-9.0.1, pluggy-1.6.0
rootdir: /root/flashinfer
configfile: pytest.ini
collected 90 items                                                                                                                   

tests/gemm/test_mm_bf16.py ..........................................................................................          [100%]

========================================================= 90 passed in 3.45s =========================================================

The original issue (#1974) asked whether a CUTLASS-backend BF16 GEMM could do better at smaller batch sizes. Using linear_mm with autotuning now gives the following results (updated script here):

Benchmark Results
(flashinfer) root@spry-shaggy-smilodon:~/flashinfer# python benchmark_linear.py 
2025-11-17 05:53:37,584 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-17 05:53:37,859 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
batch=1
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         12.354560       0.010645     1357.98      1358.97      1.00x
2. torch.compile()                  12.349120       0.016171     1358.58      1359.57      1.00x
3. max-autotune ncg                 5.017280        0.021958     3343.89      3346.34      2.46x
4. TGV GEMM pdl=False               6.540480        0.010113     2565.14      2567.01      1.89x
5. TGV GEMM pdl=True                6.099520        0.015821     2750.58      2752.59      2.03x
6. MM BF16                          9.650880        0.010973     1738.41      1739.69      1.28x

batch=2
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         11.057280       0.015711     3034.60      1519.52      1.00x
2. torch.compile()                  11.058559       0.009181     3034.25      1519.35      1.00x
3. max-autotune ncg                 11.062080       0.015267     3033.28      1518.86      1.00x
4. TGV GEMM pdl=False               6.560000        0.019785     5115.00      2561.25      1.69x
5. TGV GEMM pdl=True                6.123200        0.017838     5479.88      2743.96      1.81x
6. MM BF16                          7.129600        0.013278     4706.36      2356.62      1.55x

batch=4
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         11.060160       0.012899     6067.62      1521.35      1.00x
2. torch.compile()                  11.062720       0.015620     6066.22      1521.00      1.00x
3. max-autotune ncg                 11.064000       0.015849     6065.52      1520.82      1.00x
4. TGV GEMM pdl=False               6.556480        0.015751     10235.50     2566.37      1.69x
5. TGV GEMM pdl=True                6.122880        0.021713     10960.34     2748.11      1.81x
6. MM BF16                          7.208640        0.012594     9309.50      2334.19      1.53x

batch=8
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         12.005440       0.011114     11179.74     1405.66      1.00x
2. torch.compile()                  11.999360       0.015240     11185.41     1406.37      1.00x
3. max-autotune ncg                 11.997440       0.015384     11187.20     1406.59      1.00x
4. TGV GEMM pdl=False               6.547520        0.011645     20499.02     2577.39      1.83x
5. TGV GEMM pdl=True                6.127040        0.017312     21905.80     2754.27      1.96x
6. MM BF16                          7.579840        0.028135     17707.20     2226.37      1.58x

batch=16
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         8.046400        0.018652     33360.94     2109.49      1.00x
2. torch.compile()                  8.064640        0.016166     33285.48     2104.72      1.00x
3. max-autotune ncg                 8.056320        0.021365     33319.86     2106.90      1.00x
4. TGV GEMM pdl=False               6.656000        0.014604     40329.85     2550.15      1.21x
5. TGV GEMM pdl=True                6.190080        0.013137     43365.43     2742.10      1.30x
6. MM BF16                          7.120960        0.015677     37696.53     2383.64      1.13x

batch=32
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         8.078080        0.018555     66460.21     2125.56      1.00x
2. torch.compile()                  8.084480        0.010366     66407.60     2123.88      1.00x
3. max-autotune ncg                 8.063040        0.016551     66584.18     2129.52      1.00x
4. TGV GEMM pdl=False               6.957120        0.017935     77168.56     2468.04      1.16x
5. TGV GEMM pdl=True                6.484800        0.011767     82789.12     2647.80      1.25x
6. MM BF16                          7.397760        0.004454     72572.09     2321.03      1.09x

batch=64
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         8.182080        0.045039     131230.92    2146.60      1.00x
2. torch.compile()                  8.167680        0.016335     131462.28    2150.38      1.00x
3. max-autotune ncg                 8.189760        0.012934     131107.85    2144.59      1.00x
4. TGV GEMM pdl=False               12.971520       0.028602     82776.87     1354.02      0.63x
5. TGV GEMM pdl=True                12.534400       0.012996     85663.60     1401.24      0.65x
6. MM BF16                          7.150080        0.020649     150172.00    2456.43      1.14x

(flashinfer) root@spry-shaggy-smilodon:~/flashinfer# 

The highlight is that at a larger batch size like batch=64 we are at ~7 microseconds, while TGV is at ~12.5-13 microseconds and original/torch.compile() is at ~8 microseconds. I'm a CUTLASS newbie, so adding more tile sizes and cluster shapes, or autotuning over a wider space, may give even better performance, since we are slightly slower elsewhere.

For reviewers:

  • I just picked tile sizes/cluster shapes from the FP8 implementations that compiled. We may be able to get better performance with more tile sizes/cluster shapes, especially targeting smaller batch sizes.
  • Let me know if different test coverage is needed in terms of sizes.
  • This is just SM100. I think this can largely be reused for SM120. If you are happy with the changes here, I'd be happy to tackle SM120 as well. On that note, it may also be worth adding a cuDNN backend option if its performance turns out to be better.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 28baee5 and d3a53cd.

📒 Files selected for processing (1)
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/gemm/bf16_gemm_template_sm100.h
🧬 Code graph analysis (1)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-125)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_template_sm100.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)

183-190: LGTM! Macro enables flexible template instantiation.

The INSTANCE_BF16_GEMM_TEMPLATE_SM100 macro correctly provides explicit template instantiation control for different tile sizes, cluster shapes, and SM types. The parameter list matches the launcher's template signature, and the macro will be used by the JIT generator to instantiate specific configurations.


49-50: Forward declarations are fine as-is.

The forward declarations in this file are not problematic duplicates requiring consolidation. The actual struct definitions of _1SM and _2SM are in include/flashinfer/gemm/bf16_gemm_cutlass_template.h (lines 44-45), while the SM100 template files provide independent forward declarations. This is the correct C++ pattern: the base template defines the types, and SM100 template files forward-declare them to specialize SMTypeAdapter<_1SM> and SMTypeAdapter<_2SM> without incurring unnecessary includes. This separation of concerns is appropriate and consistent across all GEMM implementations (bf16, fp8, fp4).

Likely an incorrect or invalid review comment.

Comment thread csrc/bf16_gemm_cutlass.cu
throw std::runtime_error("[Bf16 Gemm Runner] insufficient workspace");
}

auto can_implement = gemm.can_implement(arguments);
Contributor Author


Is there any advantage to doing these safety checks this way instead of just using the CUTLASS_CHECK macro? I saw it done this way for FP8 and FP4, so I kept it, but I'm wondering because the two seem equivalent.

@raayandhar raayandhar force-pushed the user/rdhar/cutlass_bf16_gemm_sm100 branch from 1387bed to a56d74b on November 19, 2025 00:23
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)

166-176: Critical: Protect static workspace cache from concurrent access.

The static workspace_hashmap at line 166 is accessed concurrently without synchronization. While C++11+ guarantees thread-safe static initialization, subsequent operations (find() at line 169, operator[] at lines 171 and 173) create data races when getWorkspaceSize is called from multiple threads, potentially causing:

  • Cache corruption
  • Iterator invalidation
  • Undefined behavior

The BF16 GEMM APIs can be called concurrently from multiple threads in typical inference workloads, making this a critical issue.

🔎 Add mutex protection for thread safety
+  static std::mutex workspace_mutex;
   static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;

   size_t workspace_size = 0;
+  std::lock_guard<std::mutex> lock(workspace_mutex);
   if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {
     workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
     workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size;
   } else {
     workspace_size = workspace_hashmap[std::make_tuple(m, n, k)];
   }
   return workspace_size;

Alternatively, use std::shared_mutex with std::shared_lock for reads and std::unique_lock for writes to allow concurrent readers while still protecting writes.
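A minimal sketch of that shared_mutex variant, with getWorkspaceSizeImpl stubbed out and std::map standing in for the hashed container:

#include <cstddef>
#include <iostream>
#include <map>
#include <shared_mutex>
#include <tuple>

using MNK = std::tuple<int, int, int>;

// Placeholder for the real per-shape workspace computation.
size_t getWorkspaceSizeImpl(int m, int n, int k) {
  return static_cast<size_t>(m) * n + k;
}

size_t getWorkspaceSizeCached(int m, int n, int k) {
  static std::shared_mutex mtx;
  static std::map<MNK, size_t> cache;
  const MNK key{m, n, k};
  {
    std::shared_lock<std::shared_mutex> read_lock(mtx);  // many concurrent readers
    auto it = cache.find(key);
    if (it != cache.end()) return it->second;
  }
  const size_t computed = getWorkspaceSizeImpl(m, n, k);
  std::unique_lock<std::shared_mutex> write_lock(mtx);  // exclusive only for insert
  return cache.emplace(key, computed).first->second;    // keeps the first value on a race
}

int main() {
  std::cout << getWorkspaceSizeCached(64, 64, 128) << "\n";
  return 0;
}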

flashinfer/gemm/gemm_base.py (1)

818-826: Correct the type annotation for bias parameter.

The bias parameter can be None (e.g., bmm_bf16 calls this with None at line 509), but the type annotation says torch.Tensor. This should be Optional[torch.Tensor] for correctness.

Proposed fix
 def bf16_gemm_sm100(
     a: torch.Tensor,
     b: torch.Tensor,
-    bias: torch.Tensor,
+    bias: Optional[torch.Tensor],
     pdl: bool,
     out: torch.Tensor,
     workspace_buffer: torch.Tensor,
     runner_names: List[str],
 ) -> None:
🧹 Nitpick comments (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)

157-164: Consider a collision-resistant hash combiner.

The MNKHash function uses XOR to combine three hash values (h1 ^ h2 ^ h3), which can produce identical hashes for different permutations of the same values. For example, (m=1, n=2, k=3) and (m=3, n=2, k=1) would hash to the same value, potentially causing the cache to return incorrect workspace sizes.

While collisions here would only cause redundant recomputation (not correctness issues), a more robust hash combiner would improve cache efficiency.

🔎 Collision-resistant hash combiner
   struct MNKHash {
     size_t operator()(const MNK& mnk) const {
       auto h1 = std::hash<int>{}(std::get<0>(mnk));
       auto h2 = std::hash<int>{}(std::get<1>(mnk));
       auto h3 = std::hash<int>{}(std::get<2>(mnk));
-      return h1 ^ h2 ^ h3;
+      // Combine hashes to avoid collisions from permutations
+      size_t seed = h1;
+      seed ^= h2 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+      seed ^= h3 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+      return seed;
     }
   };
flashinfer/gemm/gemm_base.py (1)

219-222: Clarify the TGV output dtype restriction error message.

The current error message is confusing. When out_dtype != torch.bfloat16, it says "You cannot provide an output dtype" which is misleading—the real constraint is that TGV only supports bfloat16 output. Consider rewording for clarity:

Proposed fix
     if out_dtype != torch.bfloat16:
         raise ValueError(
-            "You cannot provide an output dtype to the TGV backend. Use the CUTLASS backend instead."
+            "TGV backend only supports bfloat16 output dtype. Use the CUTLASS backend for fp16 output."
         )
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7796098 and e56ae28.

📒 Files selected for processing (13)
  • csrc/bf16_gemm_cutlass.cu
  • csrc/bf16_gemm_cutlass.jinja
  • docs/api/gemm.rst
  • flashinfer/__init__.py
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/jit/gemm/__init__.py
  • flashinfer/jit/gemm/core.py
  • include/flashinfer/gemm/bf16_gemm_cutlass.h
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h
  • tests/gemm/test_bmm_bf16.py
  • tests/gemm/test_mm_bf16.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • tests/gemm/test_bmm_bf16.py
  • csrc/bf16_gemm_cutlass.jinja
  • include/flashinfer/gemm/bf16_gemm_cutlass.h
  • flashinfer/gemm/__init__.py
  • tests/gemm/test_mm_bf16.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/gemm/bf16_gemm_template_sm100.h
  • csrc/bf16_gemm_cutlass.cu
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h
  • flashinfer/__init__.py
  • flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (7)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-125)
  • gemm (42-91)
flashinfer/jit/gemm/__init__.py (1)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-236)
csrc/bf16_gemm_cutlass.cu (3)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • CutlassBf16GemmRunnerInterface (29-41)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
csrc/tvm_ffi_utils.h (2)
  • get_stream (294-296)
  • encode_dlpack_dtype (30-32)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (5)
  • gemm (44-181)
  • cutlass (135-135)
  • cutlass (136-136)
  • cutlass (137-137)
  • cutlass (138-138)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
flashinfer/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
  • bmm_bf16 (441-510)
  • mm_bf16 (285-384)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
  • JitSpec (216-397)
  • gen_jit_spec (400-466)
flashinfer/compilation_context.py (1)
  • get_nvcc_flags_list (50-68)
flashinfer/gemm/gemm_base.py (1)
flashinfer/utils.py (4)
  • supported_compute_capability (819-899)
  • backend_requirement (902-1184)
  • _get_cache_buf (206-217)
  • suitable_auto_backends (1076-1096)
🪛 Ruff (0.14.10)
flashinfer/jit/gemm/core.py

232-232: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation

Replace with [*nvcc_flags, "-DENABLE_BF16"]

(RUF005)

flashinfer/gemm/gemm_base.py

187-187: Unused function argument: a

(ARG001)


188-188: Unused function argument: b

(ARG001)


189-189: Unused function argument: out

(ARG001)


193-193: Unused function argument: backend

(ARG001)


196-198: Avoid specifying long messages outside the exception class

(TRY003)


200-202: Avoid specifying long messages outside the exception class

(TRY003)


211-211: Unused function argument: a

(ARG001)


212-212: Unused function argument: b

(ARG001)


213-213: Unused function argument: out

(ARG001)


215-215: Unused function argument: bias

(ARG001)


216-216: Unused function argument: pdl

(ARG001)


217-217: Unused function argument: backend

(ARG001)


220-222: Avoid specifying long messages outside the exception class

(TRY003)


230-230: Unused function argument: pdl

(ARG001)


231-231: Unused function argument: out

(ARG001)


232-232: Unused function argument: out_dtype

(ARG001)


233-233: Unused function argument: backend

(ARG001)


236-238: Avoid specifying long messages outside the exception class

(TRY003)


240-242: Avoid specifying long messages outside the exception class

(TRY003)


245-247: Avoid specifying long messages outside the exception class

(TRY003)


255-255: Unused function argument: b

(ARG001)


258-258: Unused function argument: out

(ARG001)


259-259: Unused function argument: out_dtype

(ARG001)


260-260: Unused function argument: backend

(ARG001)


355-357: Avoid specifying long messages outside the exception class

(TRY003)


359-361: Avoid specifying long messages outside the exception class

(TRY003)


363-365: Avoid specifying long messages outside the exception class

(TRY003)


389-389: Unused function argument: A

(ARG001)


390-390: Unused function argument: B

(ARG001)


391-391: Unused function argument: out

(ARG001)


393-393: Unused function argument: backend

(ARG001)


403-403: Unused function argument: out

(ARG001)


404-404: Unused function argument: out_dtype

(ARG001)


405-405: Unused function argument: backend

(ARG001)


408-410: Avoid specifying long messages outside the exception class

(TRY003)


412-414: Avoid specifying long messages outside the exception class

(TRY003)


421-421: Unused function argument: A

(ARG001)


422-422: Unused function argument: B

(ARG001)


423-423: Unused function argument: out

(ARG001)


424-424: Unused function argument: out_dtype

(ARG001)


425-425: Unused function argument: backend

(ARG001)


446-446: Unused function argument: backend

(ARG001)


494-496: Avoid specifying long messages outside the exception class

(TRY003)


498-500: Avoid specifying long messages outside the exception class

(TRY003)


502-504: Avoid specifying long messages outside the exception class

(TRY003)


770-770: Unused method argument: inputs

(ARG002)


771-771: Unused method argument: profile

(ARG002)


779-779: Unused method argument: do_preparation

(ARG002)


780-780: Unused method argument: kwargs

(ARG002)


1068-1070: Avoid specifying long messages outside the exception class

(TRY003)


1072-1074: Avoid specifying long messages outside the exception class

(TRY003)


1076-1078: Avoid specifying long messages outside the exception class

(TRY003)


1655-1658: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (24)
docs/api/gemm.rst (1)

10-18: LGTM! Documentation structure is correct.

The BF16 GEMM section is well-positioned and follows the established pattern for documenting GEMM operations in this module.

flashinfer/__init__.py (1)

88-91: LGTM! BF16 GEMM APIs properly exposed at package root.

The imports of bmm_bf16 and mm_bf16 correctly wire the root package API to the BF16 GEMM functionality, aligning with the implementation in flashinfer/gemm/gemm_base.py and its re-exports in flashinfer/gemm/__init__.py.

flashinfer/jit/gemm/__init__.py (2)

22-22: LGTM! BF16 Cutlass generator properly imported.

The import correctly exposes the new gen_gemm_sm100_module_cutlass_bf16 function from the core module.


38-38: LGTM! Export added to all.

The function is properly included in the public API surface of the jit.gemm module.

flashinfer/jit/gemm/core.py (1)

193-236: LGTM! BF16 Cutlass JIT generator follows established patterns.

The function correctly mirrors the existing FP8/FP4 generation workflow, rendering templates for multiple CTA configurations and returning a properly configured JIT spec.

The static analysis tool suggests using list unpacking ([*nvcc_flags, "-DENABLE_BF16"]) instead of concatenation at line 232, but the current approach is consistent with the rest of the codebase and works correctly.

csrc/bf16_gemm_cutlass.cu (5)

40-42: LGTM! Explicit template instantiations are correct.

The explicit instantiations for CutlassBf16GemmRunner<__nv_bfloat16> and CutlassBf16GemmRunner<half> ensure the template implementations are available for linking.


49-58: LGTM! Config selection with proper bounds checking.

The getBf16GemmConfig function correctly:

  • Lazily initializes a static vector of configs
  • Validates the tactic index is within bounds
  • Provides a clear error message on out-of-bounds access
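For illustration, a standalone sketch of the lazily initialized config list plus bounds-checked tactic lookup (GemmConfig is a stand-in struct, not the real CutlassGemmConfig):

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

struct GemmConfig { const char* name; };

const GemmConfig& getConfigByTactic(int64_t tactic) {
  // Function-local static: the config list is built once, on first use.
  static const std::vector<GemmConfig> configs = {
      {"64x64x128_1x1x1"}, {"64x128x128_1x2x1"}, {"128x128x128_2x2x1"}};
  if (tactic < 0 || tactic >= static_cast<int64_t>(configs.size()))
    throw std::invalid_argument("tactic out of range: " + std::to_string(tactic));
  return configs[static_cast<size_t>(tactic)];
}

int main() {
  std::cout << getConfigByTactic(1).name << "\n";  // 64x128x128_1x2x1
  return 0;
}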

60-83: LGTM! Workspace management correctly handles both cases.

The runGemm function properly:

  • Computes required workspace size
  • Allocates temporary workspace if the provided buffer is insufficient
  • Reuses the provided workspace when adequate
  • Passes correct parameters to the GEMM runner
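A host-side sketch of that fallback logic (the real code allocates device memory; std::vector stands in here):

#include <cstddef>
#include <iostream>
#include <vector>

// Reuse the caller-provided workspace when it is big enough, otherwise
// allocate a temporary one that lives only for this call.
void runWithWorkspace(size_t required_bytes, char* provided, size_t provided_bytes) {
  std::vector<char> fallback;
  char* workspace = provided;
  if (provided == nullptr || provided_bytes < required_bytes) {
    fallback.resize(required_bytes);   // temporary workspace owned by this call
    workspace = fallback.data();
  }
  std::cout << (workspace == provided ? "reused caller buffer\n"
                                      : "allocated temporary workspace\n");
  // ... launch the GEMM with `workspace` ...
}

int main() {
  std::vector<char> big(1 << 20), small(16);
  runWithWorkspace(1024, big.data(), big.size());      // reused caller buffer
  runWithWorkspace(1024, small.data(), small.size());  // allocated temporary workspace
  return 0;
}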

85-140: LGTM! Input validation and dispatch logic is sound.

The bf16_bmm_impl function correctly:

  • Validates input dtypes
  • Handles both 2D (matrix) and 3D (batched) inputs
  • Checks dimension compatibility with clear error messages
  • Validates output shape and dtype
  • Dispatches to the appropriate template instantiation based on output dtype

Note: The past review mentioned a non-contiguous mat2 issue. This is correctly handled in the Python layer (flashinfer/gemm/gemm_base.py) where the caller ensures contiguity before passing to the FFI.
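A rough sketch of the 2D-vs-3D normalization, ignoring the column-major B convention the binding actually uses (shapes are plain vectors, not tensors):

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

struct Problem { int64_t b, m, n, k; };

// A 2D (m, k) x (k, n) matmul is treated as a batched problem with batch = 1,
// so one code path can validate shapes and dispatch.
Problem normalize(const std::vector<int64_t>& a_shape, const std::vector<int64_t>& b_shape) {
  if (a_shape.size() == 2 && b_shape.size() == 2) {
    if (a_shape[1] != b_shape[0]) throw std::invalid_argument("k mismatch");
    return {1, a_shape[0], b_shape[1], a_shape[1]};
  }
  if (a_shape.size() == 3 && b_shape.size() == 3) {
    if (a_shape[0] != b_shape[0]) throw std::invalid_argument("batch mismatch");
    if (a_shape[2] != b_shape[1]) throw std::invalid_argument("k mismatch");
    return {a_shape[0], a_shape[1], b_shape[2], a_shape[2]};
  }
  throw std::invalid_argument("expected both inputs 2D or both 3D");
}

int main() {
  Problem p = normalize({16, 128}, {128, 64});
  std::cout << p.b << " " << p.m << " " << p.n << " " << p.k << "\n";  // 1 16 64 128
  return 0;
}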


144-161: LGTM! FFI exports are properly defined.

The public bf16_gemm and bf16_gemm_tactic_num functions are correctly exposed via TVM FFI macros, enabling Python-side access to the BF16 GEMM functionality.

include/flashinfer/gemm/bf16_gemm_template_sm100.h (4)

46-64: LGTM! SM type adapters correctly configured.

The SMTypeAdapter specializations properly map _1SM and _2SM to their respective:

  • Scale factors (1 and 2)
  • Epilogue schedules (TmaWarpSpecialized1Sm and TmaWarpSpecialized2Sm)
  • Mainloop schedules (KernelTmaWarpSpecialized1SmSm100 and KernelTmaWarpSpecialized2SmSm100)
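The underlying pattern is a small trait adapter keyed on marker types; a self-contained sketch (tag and schedule names are stand-ins, not the CUTLASS types):

#include <iostream>

struct _1SM {};
struct _2SM {};

struct TmaWarpSpecialized1SmTag {};  // stand-ins for the real schedule types
struct TmaWarpSpecialized2SmTag {};

template <typename SM>
struct SMTypeAdapter;  // primary template intentionally undefined

template <>
struct SMTypeAdapter<_1SM> {
  static constexpr int Scale = 1;
  using Schedule = TmaWarpSpecialized1SmTag;
};

template <>
struct SMTypeAdapter<_2SM> {
  static constexpr int Scale = 2;
  using Schedule = TmaWarpSpecialized2SmTag;
};

int main() {
  std::cout << SMTypeAdapter<_1SM>::Scale << " " << SMTypeAdapter<_2SM>::Scale << "\n";  // 1 2
  return 0;
}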

66-153: LGTM! Launcher correctly configures Cutlass GEMM.

The genericBf16GemmKernelLauncherSm100 function properly:

  • Defines element types, layouts, and alignments for A, B, C, D
  • Handles conditional type selection for BF16 support
  • Builds CollectiveEpilogue and CollectiveMainloop with appropriate tile and cluster shapes
  • Constructs stride descriptors for batched operation
  • Configures fusion arguments (alpha/beta)

183-190: LGTM! Macro correctly instantiates BF16 GEMM templates.

The INSTANCE_BF16_GEMM_TEMPLATE_SM100 macro properly expands to explicit template instantiations for the launcher with specified tile and cluster configurations.


154-178: The workspace query mechanism is working as designed.

The if (!A && !B && !D) pattern at lines 154-155 is the established, intentional design across all GEMM implementations (bf16, fp8, fp4). When probing for workspace size, getWorkspaceSizeImpl() explicitly calls dispatchToArch() with nullptr for all data pointers (A, B, D) along with nullptr for workspacePtr and workspaceBytes=0. The data pointer check correctly detects and handles this probe scenario, returning the required workspace size without triggering subsequent validation. Checking workspacePtr instead would not improve robustness, as the pointer is part of the function signature regardless. The pattern's correctness is further validated by the error-handling comments in similar implementations that explicitly document swallowing SMEM constraint errors during configuration probing.

Likely an incorrect or invalid review comment.

include/flashinfer/gemm/bf16_gemm_cutlass_template.h (7)

44-45: LGTM! Forward declarations for SM types.

The forward declarations of _1SM and _2SM are correct and used by the SMTypeAdapter specializations.


47-52: LGTM! Launcher declaration is well-formed.

The genericBf16GemmKernelLauncherSm100 function signature correctly declares all necessary template parameters and runtime arguments for the SM100 BF16 GEMM launcher.


54-91: LGTM! Cluster shape dispatch covers all supported configurations.

The dispatchGemmClusterShapeSm100 function correctly routes to the appropriate launcher based on the cluster shape (1x1x1, 1x2x1, 1x4x1, 2x1x1, 2x2x1), with proper SM type selection (_1SM or _2SM) and error handling for unsupported shapes.


93-125: LGTM! Tile configuration dispatch handles all SM100 tile sizes.

The dispatchToArch function correctly:

  • Routes to the appropriate tile configuration based on tile_config_sm100
  • Swaps A/B matrices and m/n dimensions to match column-major layout requirements
  • Throws for unsupported tile configurations
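The swap works because a row-major C = A*B occupies the same memory as a column-major C^T = B^T * A^T; a tiny naive sketch:

#include <cstdio>
#include <vector>

// Naive column-major GEMM: A is m x k, B is k x n, C is m x n, leading
// dimensions equal to the number of rows.
void gemm_col_major(int m, int n, int k, const float* A, const float* B, float* C) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) acc += A[i + p * m] * B[p + j * k];
      C[i + j * m] = acc;
    }
}

int main() {
  // Row-major A (2x3) and B (3x2); we want row-major C = A * B (2x2).
  std::vector<float> A{1, 2, 3, 4, 5, 6}, B{7, 8, 9, 10, 11, 12}, C(4);
  // Row-major X is the same bytes as column-major X^T, so compute
  // C^T = B^T * A^T by passing the buffers swapped and (m, n) swapped.
  gemm_col_major(/*m=*/2, /*n=*/2, /*k=*/3, /*A=*/B.data(), /*B=*/A.data(), C.data());
  std::printf("%g %g %g %g\n", C[0], C[1], C[2], C[3]);  // 58 64 139 154 == row-major A*B
  return 0;
}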

127-134: LGTM! GEMM entry point correctly delegates to arch dispatch.

The CutlassBf16GemmRunner<T>::gemm implementation properly forwards all parameters to the architecture-specific dispatch function.


136-151: LGTM! Workspace size probing handles SMEM constraint failures.

The getWorkspaceSizeImpl function correctly:

  • Probes all available GEMM configurations
  • Catches and silently ignores std::runtime_error exceptions when configurations exceed SMEM limits
  • Returns the maximum workspace size across all valid configurations

The comment "Swallow errors when SMEM exceeds maximum allowed" clearly documents this intentional behavior. Based on learnings from the FP8 implementation, this pattern is acceptable for configuration discovery.


178-202: LGTM! Config generation covers all tile and cluster combinations.

The getConfigs function correctly:

  • Enumerates all 5 SM100 tile configurations (64x64x128, 64x128x128, 64x256x128, 128x64x128, 128x128x128)
  • Combines with all 5 cluster shapes (1x1x1, 1x2x1, 1x4x1, 2x1x1, 2x2x1)
  • Creates 25 candidate configurations for autotuning
  • Uses AUTO schedule types for flexibility
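A standalone sketch of that 5 x 5 enumeration (enum values are illustrative, not the real CutlassTileConfigSM100/ClusterShape entries):

#include <initializer_list>
#include <iostream>
#include <vector>

enum class Tile { T64x64x128, T64x128x128, T64x256x128, T128x64x128, T128x128x128 };
enum class Cluster { C1x1x1, C1x2x1, C1x4x1, C2x1x1, C2x2x1 };

struct Config { Tile tile; Cluster cluster; };

// Pair every tile shape with every cluster shape to form the candidate list
// the autotuner sweeps.
std::vector<Config> getConfigs() {
  std::vector<Config> configs;
  for (Tile t : {Tile::T64x64x128, Tile::T64x128x128, Tile::T64x256x128,
                 Tile::T128x64x128, Tile::T128x128x128})
    for (Cluster c : {Cluster::C1x1x1, Cluster::C1x2x1, Cluster::C1x4x1,
                      Cluster::C2x1x1, Cluster::C2x2x1})
      configs.push_back({t, c});
  return configs;
}

int main() {
  std::cout << getConfigs().size() << " candidate configs\n";  // 25
  return 0;
}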
flashinfer/gemm/gemm_base.py (3)

276-384: Well-structured BF16 MM API with proper backend routing.

The mm_bf16 function follows established patterns with proper decorator usage (@backend_requirement, @flashinfer_api), clear documentation, and correct backend selection logic. The integration with the autotuner via bf16_gemm_sm100 is well-implemented.


433-510: Clean BMM BF16 implementation ready for backend expansion.

The bmm_bf16 function is well-structured with proper validation and workspace management. The design anticipates future cuDNN backend support (as noted in the author's comment at line 438), making it extensible.


762-796: CUTLASS BF16 runner implementation is consistent with FP8 pattern.

The CutlassBf16GemmRunner follows the same structure as the FP8 runner, with proper tactic enumeration and tensor unpacking. The b.transpose(-2, -1) pattern on line 785 is consistent with the FP8 implementation (line 726).

Collaborator

@aleozlx aleozlx left a comment


looks good so far

@raayandhar
Contributor Author

@aleozlx could you trigger CI now that the checks have passed? (If you want others to review first, just ignore this.)

@aleozlx
Collaborator

aleozlx commented Jan 7, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !226 has been created, and the CI pipeline #41300448 is currently running. I'll report back once the pipeline job completes.

Collaborator

@bkryu bkryu left a comment


Hi @raayandhar, this will be a great addition to FlashInfer.

To facilitate benchmarking, would it be possible to add benchmarking support in flashinfer_benchmark.py, either in this PR or a subsequent one? A reference benchmark routine for bmm_fp8 can be found here.

One strength of the flashinfer_benchmark.py microbenchmark harness is its ability to output structured CSV results and compare backends' performance on the same inputs. As we add the mm_bf16 and bmm_bf16 APIs, adding them to the benchmark would help with performance tracking.

@raayandhar
Contributor Author

Hi @raayandhar, this will be a great addition to FlashInfer.

To facilitate benchmarking, would it be possible to add in this or a subsequent PR, benchmarking support in flashinfer_benchmark.py? A reference benchmark routine for bmm_fp8 can be found here.

One strength of the flashinfer_benchmark.py microbenchmark harness is its ability to output a structured csv for results and compare backends' performances with the same inputs. As we add mm_bf16 and bmm_bf16 APIs, adding them to the benchmark would be helpful in performance tracking.

Yes, I was planning on adding support in a follow-up PR (this week, probably); I already spoke to @aleozlx about it.

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @tests/gemm/test_mm_bf16.py:
- Line 25: The code calls get_compute_capability(torch.device(device="cuda")) but references an undefined variable named device. Update the call to use a string literal instead (e.g., torch.device("cuda") or simply "cuda") so that get_compute_capability receives a valid device object/string: locate the call to get_compute_capability and change the torch.device invocation to remove the undefined named argument.
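For reference, the architecture-gated skip pattern the repo guidelines call for might look roughly like this; the (major, minor) return shape of get_compute_capability and the SM100 check are assumptions:

```python
import pytest
import torch
from flashinfer.utils import get_compute_capability

def require_sm100():
    # Assumption: get_compute_capability returns (major, minor) for the device,
    # like torch.cuda.get_device_capability. SM100 is compute capability 10.x.
    major, _ = get_compute_capability(torch.device("cuda"))
    if major != 10:
        pytest.skip("BF16 CUTLASS GEMM targets SM100 GPUs (e.g., B200)")
```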
🧹 Nitpick comments (2)
flashinfer/gemm/gemm_base.py (2)

1650-1657: Consider deduplicating output dtype validation.

The _validate_bf16_output_dtype function (lines 1650-1657) is identical to _validate_fp8_output_dtype (lines 1641-1648). Both validate that output dtype is either torch.bfloat16 or torch.float16.

♻️ Optional refactor to reduce duplication
-def _validate_bf16_output_dtype(dtype: torch.dtype):
-    """Validate that the output dtype is either bf16 or fp16."""
-    if dtype not in (torch.bfloat16, torch.float16):
-        raise ValueError(
-            f"Unsupported output dtype: {dtype}. "
-            f"Only torch.bfloat16 and torch.float16 are supported for BF16 GEMM operations."
-        )
+def _validate_gemm_output_dtype(dtype: torch.dtype, operation: str = "GEMM"):
+    """Validate that the output dtype is either bf16 or fp16."""
+    if dtype not in (torch.bfloat16, torch.float16):
+        raise ValueError(
+            f"Unsupported output dtype: {dtype}. "
+            f"Only torch.bfloat16 and torch.float16 are supported for {operation} operations."
+        )
+
+def _validate_fp8_output_dtype(dtype: torch.dtype):
+    """Validate that the output dtype is either bf16 or fp16."""
+    _validate_gemm_output_dtype(dtype, "FP8 GEMM")
+
+def _validate_bf16_output_dtype(dtype: torch.dtype):
+    """Validate that the output dtype is either bf16 or fp16."""
+    _validate_gemm_output_dtype(dtype, "BF16 GEMM")

3897-3897: Trailing blank line in function.

Line 3897 has a blank line that appears to be unintentional within the function body. This is a minor cosmetic issue.

♻️ Clean up trailing blank line
     if out is None:
         out_dtype = out_dtype or torch.bfloat16
         out = torch.empty(a.shape[0], b.shape[1], dtype=out_dtype, device=a.device)
-
     m_grouped_fp8_gemm_nt_contiguous(
         (a, a_scale), (b, b_scale), out, m_indices, scale_granularity_mnk
     )
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4b3f85f and 9112fa3.

📒 Files selected for processing (3)
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_bmm_bf16.py
  • tests/gemm/test_mm_bf16.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/gemm/test_bmm_bf16.py
🧰 Additional context used
📓 Path-based instructions (2)
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/gemm/test_mm_bf16.py
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/gemm/gemm_base.py
🧠 Learnings (13)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures

Applied to files:

  • tests/gemm/test_mm_bf16.py
  • flashinfer/gemm/gemm_base.py
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers

Applied to files:

  • flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Applied to files:

  • flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers

Applied to files:

  • flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/**/*.py : Use `flashinfer_api` decorator for debugging API calls, enable via `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Applied to files:

  • flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API

Applied to files:

  • flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed

Applied to files:

  • flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Use `FLASHINFER_CUDA_ARCH_LIST` environment variable to specify target GPU architectures (e.g., '8.0 9.0a') and `FLASHINFER_NVCC_THREADS` to control parallel compilation threads

Applied to files:

  • flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/**/*.py : Use `functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation

Applied to files:

  • flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Use `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`

Applied to files:

  • flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (2)
tests/gemm/test_mm_bf16.py (5)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
  • flashinfer (41-125)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • flashinfer (26-60)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm/gemm_base.py (1)
  • mm_bf16 (283-382)
flashinfer/utils.py (1)
  • get_compute_capability (258-261)
flashinfer/gemm/gemm_base.py (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
  • gemm (42-91)
csrc/bf16_gemm_cutlass.cu (2)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
🪛 Ruff (0.14.10)
flashinfer/gemm/gemm_base.py

187-187: Unused function argument: a

(ARG001)


188-188: Unused function argument: b

(ARG001)


189-189: Unused function argument: out

(ARG001)


193-193: Unused function argument: backend

(ARG001)


196-198: Avoid specifying long messages outside the exception class

(TRY003)


200-202: Avoid specifying long messages outside the exception class

(TRY003)


211-211: Unused function argument: a

(ARG001)


212-212: Unused function argument: b

(ARG001)


213-213: Unused function argument: out

(ARG001)


215-215: Unused function argument: bias

(ARG001)


216-216: Unused function argument: pdl

(ARG001)


217-217: Unused function argument: backend

(ARG001)


220-222: Avoid specifying long messages outside the exception class

(TRY003)


230-230: Unused function argument: pdl

(ARG001)


231-231: Unused function argument: out

(ARG001)


232-232: Unused function argument: out_dtype

(ARG001)


233-233: Unused function argument: backend

(ARG001)


236-238: Avoid specifying long messages outside the exception class

(TRY003)


240-242: Avoid specifying long messages outside the exception class

(TRY003)


245-247: Avoid specifying long messages outside the exception class

(TRY003)


254-254: Unused function argument: a

(ARG001)


255-255: Unused function argument: b

(ARG001)


258-258: Unused function argument: out

(ARG001)


259-259: Unused function argument: out_dtype

(ARG001)


260-260: Unused function argument: backend

(ARG001)


353-355: Avoid specifying long messages outside the exception class

(TRY003)


357-359: Avoid specifying long messages outside the exception class

(TRY003)


361-363: Avoid specifying long messages outside the exception class

(TRY003)


387-387: Unused function argument: A

(ARG001)


388-388: Unused function argument: B

(ARG001)


389-389: Unused function argument: out

(ARG001)


391-391: Unused function argument: backend

(ARG001)


401-401: Unused function argument: out

(ARG001)


402-402: Unused function argument: out_dtype

(ARG001)


403-403: Unused function argument: backend

(ARG001)


406-408: Avoid specifying long messages outside the exception class

(TRY003)


410-412: Avoid specifying long messages outside the exception class

(TRY003)


419-419: Unused function argument: A

(ARG001)


420-420: Unused function argument: B

(ARG001)


421-421: Unused function argument: out

(ARG001)


422-422: Unused function argument: out_dtype

(ARG001)


423-423: Unused function argument: backend

(ARG001)


444-444: Unused function argument: backend

(ARG001)


492-494: Avoid specifying long messages outside the exception class

(TRY003)


496-498: Avoid specifying long messages outside the exception class

(TRY003)


500-502: Avoid specifying long messages outside the exception class

(TRY003)


768-768: Unused method argument: inputs

(ARG002)


769-769: Unused method argument: profile

(ARG002)


777-777: Unused method argument: do_preparation

(ARG002)


778-778: Unused method argument: kwargs

(ARG002)


1066-1068: Avoid specifying long messages outside the exception class

(TRY003)


1070-1072: Avoid specifying long messages outside the exception class

(TRY003)


1074-1076: Avoid specifying long messages outside the exception class

(TRY003)


1653-1656: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/gemm/gemm_base.py (1)

370-379: Verify backend parameter handling in heuristic calls.

When backend="cutlass" or backend="tgv", the code calls _heuristic_func_mm_bf16 with hardcoded values (None, False for cutlass on line 372, and bias, pdl for tgv on line 376). However, the function signature and validation should have already rejected invalid combinations (e.g., cutlass with bias).

This pattern seems intentional but could be clarified. If a user somehow bypasses validation and calls with backend="cutlass", bias=<tensor>, line 372 silently ignores the bias rather than raising an error.

Based on learnings and the @backend_requirement decorator behavior, the requirement functions should prevent invalid combinations from reaching this code. However, consider whether explicit parameter passing (rather than hardcoding) would be clearer:

elif backend == "cutlass":
    backends = _heuristic_func_mm_bf16(
        ["cutlass"], a, b, bias, pdl, out, out_dtype, backend
    )

This would make the heuristic function responsible for filtering, which it already does.

Comment thread tests/gemm/test_mm_bf16.py
@flashinfer-bot
Collaborator

[FAILED] Pipeline #41300448: 4/20 passed

@bkryu
Collaborator

bkryu commented Jan 8, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !226 has been updated with latest changes, and the CI pipeline #41366487 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #41366487: canceled

Collaborator

@bkryu bkryu left a comment

LGTM. Although the pipeline shows as canceled, unit tests passed on all key SKUs.

@raayandhar
Contributor Author

@bkryu @aleozlx anything left before we can merge?

@aleozlx aleozlx enabled auto-merge (squash) January 9, 2026 20:59
@aleozlx
Collaborator

aleozlx commented Jan 9, 2026

oh i see, just needs a code owner approval
@jimmyzho or @yzh119 could you help us merge it in? thanks

Contributor

@jimmyzho jimmyzho left a comment

just small nit, LGTM!

```python
if out.dtype != out_dtype:
    raise ValueError(
        f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}."
    )
```
Contributor

[nit] can you move these checks to _check_mm_bf16_problem_size?
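A rough sketch of what folding that check into _check_mm_bf16_problem_size could look like; the function's real parameter list is not shown in this thread, so the signature here is an assumption:

```python
def _check_mm_bf16_problem_size(a, b, out, out_dtype):
    # ... existing shape / dtype / device validation ...
    if out is not None and out.dtype != out_dtype:
        raise ValueError(
            f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}."
        )
```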

Contributor Author

Yeah I'll do that in my cuDNN PR.

@aleozlx aleozlx merged commit 2062dec into flashinfer-ai:main Jan 10, 2026
12 checks passed
@aidando73
Contributor

> I just picked tile sizes/cluster shapes that compiled from the FP8 implementations. Maybe we can get better performance with more tile sizes/cluster shapes? Especially targeting smaller batch sizes.

@raayandhar one of the cutlass maintainers mentioned a transpose trick you can do when M is small NVIDIA/cutlass#2923 (comment) - with stream_k and split_k - wondering if you tried that?

I'm currently taking a look at the block-wise fp8 kernel right now NVIDIA/cutlass#2923 <- lmk if you had any context that would be useful for me here. The cutlass backend performs pretty badly for smaller batch sizes / shapes.
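For context, the trick referenced here boils down to the identity C = (B^T A^T)^T: when M is tiny, running the transposed problem lets the GEMM treat the large dimension as its M. A framework-level illustration of the math only (the stream-K/split-K scheduling details in the linked issue are not reproduced here):

```python
import torch

def mm_small_m_via_transpose(a, b):
    # C = A @ B computed as (B^T @ A^T)^T: the same product, but the underlying
    # GEMM now sees the large N dimension where it previously saw a tiny M.
    return (b.transpose(-2, -1) @ a.transpose(-2, -1)).transpose(-2, -1)

a = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")     # small M
b = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
torch.testing.assert_close(mm_small_m_via_transpose(a, b), a @ b, rtol=2e-2, atol=2e-2)
```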

@raayandhar
Contributor Author

> @raayandhar one of the cutlass maintainers mentioned a transpose trick you can do when M is small NVIDIA/cutlass#2923 (comment) - with stream_k and split_k - wondering if you tried that?
>
> I'm currently taking a look at the block-wise fp8 kernel right now NVIDIA/cutlass#2923 <- lmk if you had any context that would be useful for me here. The cutlass backend performs pretty badly for smaller batch sizes / shapes.

Unfortunately no, I did not try that trick. But it's something I could try in the future; I'll be following that issue thread.

Yeah, if you look at the benchmark script from earlier in this PR, I also observed that we were not performing very well at lower batch sizes. But unfortunately I don't have much more context than that. Maybe @aleozlx or @bkryu might know more?
