feat: BF16 GEMM using CUTLASS backend for SM100 (#2070)
aleozlx merged 18 commits into flashinfer-ai:main
Conversation
📝 Walkthrough
Adds end-to-end SM100 BF16 GEMM: Cutlass-based CUDA runner and TVM FFI, SM100 Jinja codegen and kernel templates, public C++ runner headers, Python mm_bf16/bmm_bf16 APIs with JIT/autotune integration, tests, and docs; includes workspace sizing, tactic selection, and runtime validation.
Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Py as Python caller (mm_bf16 / bmm_bf16)
participant API as flashinfer.gemm API
participant Orch as bf16_gemm_sm100 (orchestrator)
participant JIT as JIT module / generator
participant FFI as CUDA FFI (bf16_gemm)
participant Runner as CutlassBf16GemmRunner
participant Kernel as Cutlass kernel
Py->>API: call mm_bf16/bmm_bf16(a,b,opts)
API->>API: validate dtypes/shapes, prepare out & workspace
API->>Orch: bf16_gemm_sm100(a,b,out,workspace,runner_names)
Orch->>JIT: ensure/load SM100 BF16 module
JIT-->>Orch: module + FFI bindings
Orch->>FFI: call bf16_gemm(..., tactic)
FFI->>Runner: getBf16GemmConfig(m,n,k,tactic)
FFI->>Runner: runGemm<T>(...) -> compute workspace, call gemm
Runner->>Kernel: launch Cutlass kernel on CUDA stream
Kernel-->>Runner: kernel completes
Runner-->>FFI: return status
FFI-->>Orch: done
Orch-->>API: completed
API-->>Py: return tensor
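A usage sketch of this flow, modeled on the call pattern in the new tests (the exact keyword names and the autotune context are assumptions drawn from the test files, not a definitive API reference):

```python
import torch
import flashinfer
from flashinfer.autotuner import autotune

# Hypothetical shapes; mm_bf16 computes out = a @ b for bf16 inputs on SM100.
a = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
b = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
out = torch.empty(128, 512, device="cuda", dtype=torch.bfloat16)

# Running once under the autotuner profiles the available tactics and caches
# the best one; later calls with the same bucketed M reuse that tactic.
with autotune():
    flashinfer.mm_bf16(a, b, out=out, out_dtype=torch.bfloat16)
```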
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Suggested reviewers
🚥 Pre-merge checks: ✅ 2 passed, ❌ 1 failed (1 warning)
Currently there is an error about the second matrix being non-contiguous:
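For context, a minimal PyTorch illustration (hypothetical shapes, not the original error log) of how a transposed second operand becomes non-contiguous:

```python
import torch

b = torch.randn(64, 32, dtype=torch.bfloat16)
bt = b.transpose(-2, -1)                 # strided view, no data copy
print(bt.is_contiguous())                # False -> rejected by the contiguity check
print(bt.contiguous().is_contiguous())   # True  -> materialized copy passes
```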
dd6216f to aaaee56 (Compare)
Actionable comments posted: 6
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- csrc/bf16_gemm_cutlass.cu (1 hunks)
- csrc/bf16_gemm_cutlass.jinja (1 hunks)
- docs/api/gemm.rst (1 hunks)
- flashinfer/__init__.py (1 hunks)
- flashinfer/gemm/gemm_base.py (4 hunks)
- flashinfer/jit/gemm/__init__.py (2 hunks)
- flashinfer/jit/gemm/core.py (1 hunks)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h (0 hunks)
- tests/gemm/test_bmm_bf16.py (1 hunks)
- tests/gemm/test_mm_bf16.py (1 hunks)
💤 Files with no reviewable changes (1)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🧰 Additional context used
🧬 Code graph analysis (10)
flashinfer/jit/gemm/__init__.py (1)
- flashinfer/jit/gemm/core.py (1): gen_gemm_sm100_module_cutlass_bf16 (193-237)
include/flashinfer/gemm/bf16_gemm_cutlass.h (3)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2): flashinfer (41-134), gemm (42-91)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (1): gemm (44-176)
- flashinfer/gemm/gemm_base.py (1): CutlassBf16GemmRunner (497-520)
flashinfer/gemm/gemm_base.py (5)
- flashinfer/jit/gemm/core.py (2): gen_gemm_sm100_module (240-316), gen_gemm_sm100_module_cutlass_bf16 (193-237)
- flashinfer/utils.py (3): supported_compute_capability (773-853), _get_cache_buf (205-211), get_compute_capability (252-255)
- flashinfer/autotuner.py (7): TunableRunner (194-247), OptimizationProfile (168-183), AutoTuner (335-784), TuningConfig (101-141), DynamicTensorSpec (41-82), ConstraintSpec (86-97), choose_one (400-529)
- csrc/bf16_gemm_cutlass.cu (4): bf16_gemm_tactic_num (149-156), bf16_gemm_tactic_num (149-149), bf16_gemm (144-147), bf16_gemm (144-145)
- flashinfer/fused_moe/utils.py (2): get_last_power_of_2_num_tokens_buckets (206-215), last_positive_power_of_2 (183-188)
tests/gemm/test_bmm_bf16.py (3)
- flashinfer/autotuner.py (1): autotune (251-262)
- flashinfer/gemm/gemm_base.py (1): bmm_bf16 (250-313)
- flashinfer/utils.py (2): get_compute_capability (252-255), is_compute_capability_supported (979-994)
flashinfer/__init__.py (1)
- flashinfer/gemm/gemm_base.py (2): bmm_bf16 (250-313), mm_bf16 (183-246)
tests/gemm/test_mm_bf16.py (3)
- flashinfer/autotuner.py (1): autotune (251-262)
- flashinfer/gemm/gemm_base.py (1): mm_bf16 (183-246)
- flashinfer/utils.py (2): get_compute_capability (252-255), is_compute_capability_supported (979-994)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (4)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (2): flashinfer (26-60), gemm (27-59)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h (4): flashinfer (41-145), gemm (42-95), std (184-184), std (185-185)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (6): gemm (44-176), _1SM (53-57), cutlass (135-135), cutlass (136-136), cutlass (137-137), cutlass (138-138)
- include/flashinfer/gemm/cutlass_gemm_configs.h (1): CutlassTileConfigSM100 (106-425)
flashinfer/jit/gemm/core.py (2)
- flashinfer/jit/core.py (2): JitSpec (213-312), gen_jit_spec (315-381)
- flashinfer/compilation_context.py (1): get_nvcc_flags_list (50-68)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (2): flashinfer (26-60), gemm (27-59)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2): flashinfer (41-134), gemm (42-91)
csrc/bf16_gemm_cutlass.cu (4)
- flashinfer/gemm/gemm_base.py (1): CutlassBf16GemmRunner (497-520)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (1): CutlassBf16GemmRunnerInterface (29-41)
- include/flashinfer/gemm/cutlass_gemm_configs.h (1): CutlassTileConfigSM100 (106-425)
- csrc/tvm_ffi_utils.h (2): get_stream (272-274), encode_dlpack_dtype (29-31)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_cutlass.h
[error] 20-20: 'cuda_runtime_api.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_template_sm100.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
🪛 GitHub Actions: pre-commit
flashinfer/__init__.py
[error] 88-88: mypy: Module "flashinfer.gemm" has no attribute "bmm_bf16".
[error] 90-90: mypy: Module "flashinfer.gemm" has no attribute "mm_bf16".
🪛 Ruff (0.14.3)
flashinfer/gemm/gemm_base.py
218-218: Avoid specifying long messages outside the exception class
(TRY003)
220-220: Avoid specifying long messages outside the exception class
(TRY003)
230-232: Avoid specifying long messages outside the exception class
(TRY003)
234-236: Avoid specifying long messages outside the exception class
(TRY003)
238-240: Avoid specifying long messages outside the exception class
(TRY003)
284-284: Avoid specifying long messages outside the exception class
(TRY003)
286-286: Avoid specifying long messages outside the exception class
(TRY003)
297-299: Avoid specifying long messages outside the exception class
(TRY003)
301-303: Avoid specifying long messages outside the exception class
(TRY003)
305-307: Avoid specifying long messages outside the exception class
(TRY003)
500-500: Unused method argument: inputs
(ARG002)
501-501: Unused method argument: profile
(ARG002)
509-509: Unused method argument: do_preparation
(ARG002)
510-510: Unused method argument: kwargs
(ARG002)
592-592: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/jit/gemm/core.py
233-233: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation
Replace with [*nvcc_flags, "-DENABLE_BF16"]
(RUF005)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
docs/api/gemm.rst (2)
10-17: Documentation formatting is consistent and well-structured. The BF16 GEMM subsection follows the established pattern of other GEMM sections in the file (consistent indentation, autosummary directive, toctree configuration). Placement at the beginning of the GEMM API documentation is logical and appropriate.
10-17: Documentation is complete and accurate. The BF16 GEMM subsection correctly documents mm_bf16 and bmm_bf16; these are the only public-facing BF16 GEMM functions (verified by top-level exports in flashinfer/__init__.py). The bf16_gemm mentioned in the PR summary is an internal C++ binding and tuning identifier, not a public Python API.
Actionable comments posted: 3
♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)
512-520: Materialize the transposed tensor before passing to the CUTLASS runner. This is the root cause of the runtime error reported in the PR: b.transpose(-2, -1) returns a non-contiguous view, but the C++ binding requires contiguous input. The fix is to call .contiguous() on the transposed tensor. Apply this fix:

  a, b, out, workspace_buffer = inputs
  module.bf16_gemm(
      a,
-     b.transpose(-2, -1),
+     b.transpose(-2, -1).contiguous(),
      out,
      workspace_buffer,
      tactic,
  )

This issue was already identified in the previous review.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- flashinfer/gemm/__init__.py (2 hunks)
- flashinfer/gemm/gemm_base.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/gemm/gemm_base.py (7)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (1): gemm (27-59)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (1): gemm (44-176)
- flashinfer/jit/gemm/core.py (1): gen_gemm_sm100_module_cutlass_bf16 (193-237)
- flashinfer/utils.py (3): supported_compute_capability (773-853), _get_cache_buf (205-211), get_compute_capability (252-255)
- flashinfer/autotuner.py (7): TunableRunner (194-247), OptimizationProfile (168-183), AutoTuner (335-784), TuningConfig (101-141), DynamicTensorSpec (41-82), ConstraintSpec (86-97), choose_one (400-529)
- csrc/bf16_gemm_cutlass.cu (4): bf16_gemm_tactic_num (149-156), bf16_gemm_tactic_num (149-149), bf16_gemm (144-147), bf16_gemm (144-145)
- flashinfer/fused_moe/utils.py (2): get_last_power_of_2_num_tokens_buckets (206-215), last_positive_power_of_2 (183-188)
🪛 Ruff (0.14.3)
flashinfer/gemm/gemm_base.py
218-218: Avoid specifying long messages outside the exception class
(TRY003)
220-220: Avoid specifying long messages outside the exception class
(TRY003)
230-232: Avoid specifying long messages outside the exception class
(TRY003)
234-236: Avoid specifying long messages outside the exception class
(TRY003)
238-240: Avoid specifying long messages outside the exception class
(TRY003)
284-284: Avoid specifying long messages outside the exception class
(TRY003)
286-286: Avoid specifying long messages outside the exception class
(TRY003)
297-299: Avoid specifying long messages outside the exception class
(TRY003)
301-303: Avoid specifying long messages outside the exception class
(TRY003)
305-307: Avoid specifying long messages outside the exception class
(TRY003)
500-500: Unused method argument: inputs
(ARG002)
501-501: Unused method argument: profile
(ARG002)
509-509: Unused method argument: do_preparation
(ARG002)
510-510: Unused method argument: kwargs
(ARG002)
592-592: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (1)
flashinfer/gemm/__init__.py (1)
1-38: LGTM! Public API exports are correctly wired. The new BF16 GEMM functions (bmm_bf16 and mm_bf16) are properly imported from gemm_base and exposed through the module's __all__ list, making them available as part of the public API.
511d8e0 to fbe5723 (Compare)
Actionable comments posted: 3
♻️ Duplicate comments (5)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
156-173: Restore workspace probe handling before launching kernels
CutlassBf16GemmRunner::getWorkspaceSize() calls this launcher with workspacePtr == nullptr to learn how big the buffer must be. Today we immediately throw because workspaceBytes == 0, so the probe reports 0 and the next real launch still fails with "insufficient workspace". Please short-circuit the probe and return the computed size instead of throwing.

  size_t workspace_size = gemm.get_workspace_size(arguments);
+ if (workspacePtr == nullptr) {
+   return workspace_size;
+ }
  if (workspace_size > workspaceBytes) {

flashinfer/gemm/gemm_base.py (4)
182-246: Validate BF16 MM inputs before dispatch. The CUTLASS path assumes 2-D bf16 matrices on the same CUDA device with contiguous row-major layout. Without the early guards we can accept the wrong dtype, mismatched shapes/devices, or a strided view and only fail deep inside the kernel (or produce garbage). Please restore the validation/contiguity fixes before touching the workspace.
- if backend != "cutlass":
+ if backend != "cutlass":
      raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")
  if out_dtype not in (torch.bfloat16, torch.float16):
      raise ValueError("Only bf16 and fp16 outputs are supported.")
+
+ if a.ndim != 2 or b.ndim != 2:
+     raise ValueError(f"mm_bf16 expects 2D tensors. Got a.ndim={a.ndim}, b.ndim={b.ndim}.")
+ if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+     raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.")
+ if a.shape[1] != b.shape[0]:
+     raise ValueError(
+         f"Shape mismatch for matrix multiplication. a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}."
+     )
+ if a.device != b.device:
+     raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.")
+ if not a.is_contiguous():
+     a = a.contiguous()
+ if not b.is_contiguous():
+     b = b.contiguous()
+ if out is not None and not out.is_contiguous():
+     raise ValueError("Output tensor must be contiguous for the CUTLASS backend.")
249-313: Do the same validation for BMM. The batched entry point has the same holes: wrong dtype, rank, device, or non-contiguous slices go straight into CUTLASS and fail later (or worse, corrupt results). Please add the missing checks for 3-D tensors, matching batch/K dims, same device, and enforce contiguity before launching.
- if backend != "cutlass":
+ if backend != "cutlass":
      raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")
  if out_dtype not in (torch.bfloat16, torch.float16):
      raise ValueError("Only bf16 and fp16 outputs are supported.")
+
+ if a.ndim != 3 or b.ndim != 3:
+     raise ValueError(f"bmm_bf16 expects 3D tensors. Got a.ndim={a.ndim}, b.ndim={b.ndim}.")
+ if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+     raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.")
+ if a.shape[0] != b.shape[0]:
+     raise ValueError(
+         f"Batch size mismatch. a.shape[0]={a.shape[0]} must equal b.shape[0]={b.shape[0]}."
+     )
+ if a.shape[2] != b.shape[1]:
+     raise ValueError(
+         f"K dimension mismatch. a.shape[2]={a.shape[2]} must equal b.shape[1]={b.shape[1]}."
+     )
+ if a.device != b.device:
+     raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.")
+ if not a.is_contiguous():
+     a = a.contiguous()
+ if not b.is_contiguous():
+     b = b.contiguous()
+ if out is not None and not out.is_contiguous():
+     raise ValueError("Output tensor must be contiguous for the CUTLASS backend.")
512-520: Materialize column‑major B before calling the kernel
bf16_gemm still receives b.transpose(-2, -1) directly, which is a non-contiguous view and reproduces the runtime error ("mat2 must be contiguous"). Please allocate the column-major buffer before dispatching to CUTLASS.

- module.bf16_gemm(
-     a,
-     b.transpose(-2, -1),
-     out,
-     workspace_buffer,
-     tactic,
- )
+ b_col_major = b.transpose(-2, -1).contiguous()
+ module.bf16_gemm(
+     a,
+     b_col_major,
+     out,
+     workspace_buffer,
+     tactic,
+ )
590-592: Report the actual device when no runner is found. When a lives on a non-default GPU, torch.device("cuda") queries device 0 and we raise "sm100" even if the tensor was on sm90. Use the tensor's device so the error reflects reality.

- major, minor = get_compute_capability(torch.device("cuda"))
+ major, minor = get_compute_capability(a.device)
🧹 Nitpick comments (1)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
54-91: Cluster shape dispatch with limited configuration support. The function correctly dispatches based on cluster shape, with appropriate error handling for unsupported configurations. The limitation to only ClusterShape_1x1x1 aligns with the PR author's note about tile size and SMEM constraints during initial development.
Note: Line 66 has a break statement after return, which is unreachable but harmless.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
- csrc/bf16_gemm_cutlass.cu (1 hunks)
- csrc/bf16_gemm_cutlass.jinja (1 hunks)
- docs/api/gemm.rst (1 hunks)
- flashinfer/__init__.py (1 hunks)
- flashinfer/gemm/__init__.py (2 hunks)
- flashinfer/gemm/gemm_base.py (4 hunks)
- flashinfer/jit/gemm/__init__.py (2 hunks)
- flashinfer/jit/gemm/core.py (1 hunks)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h (0 hunks)
- tests/gemm/test_bmm_bf16.py (1 hunks)
- tests/gemm/test_mm_bf16.py (1 hunks)
💤 Files with no reviewable changes (1)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🚧 Files skipped from review as they are similar to previous changes (6)
- tests/gemm/test_mm_bf16.py
- flashinfer/gemm/__init__.py
- csrc/bf16_gemm_cutlass.jinja
- tests/gemm/test_bmm_bf16.py
- flashinfer/jit/gemm/__init__.py
- csrc/bf16_gemm_cutlass.cu
🧰 Additional context used
🧬 Code graph analysis (6)
flashinfer/__init__.py (1)
- flashinfer/gemm/gemm_base.py (2): bmm_bf16 (250-313), mm_bf16 (183-246)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (2): flashinfer (26-60), gemm (27-59)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (6): gemm (44-176), _1SM (53-57), cutlass (135-135), cutlass (136-136), cutlass (137-137), cutlass (138-138)
- include/flashinfer/gemm/cutlass_gemm_configs.h (1): CutlassTileConfigSM100 (106-425)
flashinfer/jit/gemm/core.py (2)
- flashinfer/jit/core.py (2): JitSpec (213-312), gen_jit_spec (315-381)
- flashinfer/compilation_context.py (1): get_nvcc_flags_list (50-68)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2): flashinfer (41-134), gemm (42-91)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (1): gemm (44-176)
flashinfer/gemm/gemm_base.py (5)
- flashinfer/jit/gemm/core.py (1): gen_gemm_sm100_module_cutlass_bf16 (193-237)
- flashinfer/utils.py (3): supported_compute_capability (773-853), _get_cache_buf (205-211), get_compute_capability (252-255)
- flashinfer/autotuner.py (5): TunableRunner (194-247), OptimizationProfile (168-183), AutoTuner (335-784), TuningConfig (101-141), choose_one (400-529)
- csrc/bf16_gemm_cutlass.cu (4): bf16_gemm_tactic_num (149-156), bf16_gemm_tactic_num (149-149), bf16_gemm (144-147), bf16_gemm (144-145)
- flashinfer/fused_moe/utils.py (2): get_last_power_of_2_num_tokens_buckets (206-215), last_positive_power_of_2 (183-188)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (2): flashinfer (26-60), gemm (27-59)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2): flashinfer (41-134), gemm (42-91)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_cutlass.h
[error] 20-20: 'cuda_runtime_api.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_template_sm100.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
🪛 Ruff (0.14.4)
flashinfer/jit/gemm/core.py
233-233: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation
Replace with [*nvcc_flags, "-DENABLE_BF16"]
(RUF005)
flashinfer/gemm/gemm_base.py
218-218: Avoid specifying long messages outside the exception class
(TRY003)
220-220: Avoid specifying long messages outside the exception class
(TRY003)
230-232: Avoid specifying long messages outside the exception class
(TRY003)
234-236: Avoid specifying long messages outside the exception class
(TRY003)
238-240: Avoid specifying long messages outside the exception class
(TRY003)
284-284: Avoid specifying long messages outside the exception class
(TRY003)
286-286: Avoid specifying long messages outside the exception class
(TRY003)
297-299: Avoid specifying long messages outside the exception class
(TRY003)
301-303: Avoid specifying long messages outside the exception class
(TRY003)
305-307: Avoid specifying long messages outside the exception class
(TRY003)
500-500: Unused method argument: inputs
(ARG002)
501-501: Unused method argument: profile
(ARG002)
509-509: Unused method argument: do_preparation
(ARG002)
510-510: Unused method argument: kwargs
(ARG002)
592-592: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (6)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
29-41: Well-designed interface for BF16 GEMM runner. The abstract interface provides a clean contract with appropriate virtual methods for GEMM operations, workspace management, and configuration enumeration. The virtual destructor is correctly included for safe polymorphic deletion.
43-57: Template class declaration follows proper separation pattern. The template class declaration correctly inherits from the interface and overrides all pure virtual methods. The separation of declaration (here) and definition (in the template header) is appropriate for template code.
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (4)
19-34: Appropriate diagnostic pragmas for CUTLASS integration. The GCC diagnostic pragmas correctly suppress strict-aliasing warnings around CUTLASS headers, which is necessary since CUTLASS may use type punning internally.
136-143: GEMM implementation correctly delegates to dispatch logic. The implementation properly forwards all parameters to dispatchToArch with appropriate type casting.
186-210: Configuration enumeration with limited initial support. The function correctly enumerates candidate configurations by combining tile configs and cluster shapes. The current limitation to a single tile configuration (CtaShape64x64x128B) and cluster shape (ClusterShape_1x1x1) aligns with the PR objectives and the author's noted constraints regarding SMEM space and limited B200 hardware access for testing. As additional tile sizes and cluster shapes are validated on SM100 hardware, uncomment the relevant lines to expand the configuration space.
99-103: Verify the intentional A↔B and m↔n parameter swap is correct for the kernel expectations. The parameter swap pattern dispatchGemmClusterShapeSm100<T, arch, 64, 64, 128>(B, A, static_cast<T*>(D), n, m, k, ...) is applied consistently across all tile configurations in both bf16 and fp8 GEMM implementations. This is paired with explicit layout declarations: LayoutA = RowMajor and LayoutB = ColumnMajor. While the consistency of this pattern across multiple files strongly suggests it is intentional for layout conversion, please confirm that this parameter reordering matches the actual kernel signature and expectations for dispatchGemmClusterShapeSm100.
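A sketch of the usual layout argument behind such a swap (this shows why it can be legitimate; it does not by itself verify this kernel's signature):

```latex
% Row-major D (m x n) occupies the same memory as column-major D^T (n x m), and
D = A B \quad\Longleftrightarrow\quad D^{\top} = B^{\top} A^{\top},
% so a launcher invoked with the operands and the m/n extents swapped, while it
% treats its output as column-major, writes exactly the bytes of row-major D = A B.
```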
Actionable comments posted: 2
♻️ Duplicate comments (7)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
166-173: Improve MNK hash mixing to avoid collisions. XORing h1 ^ h2 ^ h3 collapses different (m,n,k) permutations to the same bucket, so cached workspace sizes can be reused for incompatible shapes. Combine the hashes with a proper mixer instead.

  struct MNKHash {
    size_t operator()(const MNK& mnk) const {
      auto h1 = std::hash<int>{}(std::get<0>(mnk));
      auto h2 = std::hash<int>{}(std::get<1>(mnk));
      auto h3 = std::hash<int>{}(std::get<2>(mnk));
-     return h1 ^ h2 ^ h3;
+     size_t seed = h1;
+     seed ^= h2 + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
+     seed ^= h3 + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
+     return seed;
    }
  };
175-183: Guard the static workspace cache with a mutex.
workspace_hashmap is mutated without synchronization; concurrent calls to getWorkspaceSize will race on find()/operator[]. Protect the cache with a lock.

- static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
+ static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
+ static std::mutex workspace_mutex;
  size_t workspace_size = 0;
- if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {
-   workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
-   workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size;
- } else {
-   workspace_size = workspace_hashmap[std::make_tuple(m, n, k)];
- }
+ const MNK key = std::make_tuple(m, n, k);
+ {
+   std::lock_guard<std::mutex> lock(workspace_mutex);
+   auto it = workspace_hashmap.find(key);
+   if (it != workspace_hashmap.end()) {
+     return it->second;
+   }
+   workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
+   workspace_hashmap.emplace(key, workspace_size);
+ }
  return workspace_size;

tests/gemm/test_mm_bf16.py (1)
14-21: Skip on CPU-only test environments.
get_compute_capability(torch.device("cuda")) raises when CUDA isn't available, causing the entire suite to error out instead of skipping. Guard this with if not torch.cuda.is_available(): pytest.skip(...) before the capability query.
15-22: Gracefully skip when CUDA is unavailable. Like the MM test, calling get_compute_capability(torch.device("cuda")) without checking torch.cuda.is_available() hard-fails on CPU-only setups. Add a skip guard before querying the device.
flashinfer/gemm/gemm_base.py (4)
217-240: Validate inputs before firing the kernel.
mm_bf16 still accepts tensors with wrong dtype, shape, or device, which the CUTLASS runner interprets incorrectly (e.g., passing fp16 data corrupts results). Add explicit checks for bf16 dtype, matching inner dimensions, and matching devices at the top of the function so misuse fails fast.

+ if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+     raise ValueError(
+         f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}."
+     )
+ if a.ndim != 2 or b.ndim != 2:
+     raise ValueError(
+         f"Inputs must be 2D matrices. Got a.ndim={a.ndim}, b.ndim={b.ndim}."
+     )
+ if a.shape[1] != b.shape[0]:
+     raise ValueError(
+         f"Shape mismatch: a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}."
+     )
+ if a.device != b.device:
+     raise ValueError(
+         f"Device mismatch: a.device={a.device}, b.device={b.device}."
+     )
288-307: Add basic sanity checks for batched inputs.
bmm_bf16 also needs dtype/shape/device validation; otherwise mismatched batch sizes or wrong K dimensions surface as low-level CUTLASS failures. Please mirror the checks from mm_bf16 for 3D tensors (batch, m, k) and (batch, k, n), ensuring matching batch/K dimensions and bf16 dtype.

+ if A.dtype != torch.bfloat16 or B.dtype != torch.bfloat16:
+     raise ValueError(
+         f"Inputs must be bfloat16. Got A.dtype={A.dtype}, B.dtype={B.dtype}."
+     )
+ if A.ndim != 3 or B.ndim != 3:
+     raise ValueError(
+         f"Inputs must be 3D tensors. Got A.ndim={A.ndim}, B.ndim={B.ndim}."
+     )
+ if A.shape[0] != B.shape[0]:
+     raise ValueError(
+         f"Batch mismatch: A.shape[0]={A.shape[0]} != B.shape[0]={B.shape[0]}."
+     )
+ if A.shape[2] != B.shape[1]:
+     raise ValueError(
+         f"K mismatch: A.shape[2]={A.shape[2]} must equal B.shape[1]={B.shape[1]}."
+     )
+ if A.device != B.device:
+     raise ValueError(
+         f"Device mismatch: A.device={A.device}, B.device={B.device}."
+     )
512-519: Make the transposed B operand contiguous. The CUTLASS binding now enforces mat2.is_contiguous(). Transposing on the fly hands it a strided view and triggers the runtime error you reported. Materialize the column-major buffer before launching.

- module.bf16_gemm(
-     a,
-     b.transpose(-2, -1),
+ b_col_major = b.transpose(-2, -1).contiguous()
+ module.bf16_gemm(
+     a,
+     b_col_major,
      out,
      workspace_buffer,
      tactic,
  )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- flashinfer/gemm/gemm_base.py (4 hunks)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
- tests/gemm/test_bmm_bf16.py (1 hunks)
- tests/gemm/test_mm_bf16.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.563Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.563Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.563Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h
- flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (4)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (4)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (2): flashinfer (26-60), gemm (27-59)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h (4): flashinfer (41-145), gemm (42-95), std (184-184), std (185-185)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (6): gemm (44-176), _1SM (53-57), cutlass (135-135), cutlass (136-136), cutlass (137-137), cutlass (138-138)
- include/flashinfer/gemm/cutlass_gemm_configs.h (1): CutlassTileConfigSM100 (106-425)
tests/gemm/test_mm_bf16.py (2)
- flashinfer/gemm/gemm_base.py (1): mm_bf16 (183-246)
- flashinfer/utils.py (2): get_compute_capability (252-255), is_compute_capability_supported (979-994)
tests/gemm/test_bmm_bf16.py (3)
- flashinfer/autotuner.py (1): autotune (251-262)
- flashinfer/gemm/gemm_base.py (1): bmm_bf16 (249-312)
- flashinfer/utils.py (2): get_compute_capability (252-255), is_compute_capability_supported (979-994)
flashinfer/gemm/gemm_base.py (8)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1): gemm (42-91)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (1): gemm (27-59)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (1): gemm (44-176)
- flashinfer/jit/gemm/core.py (1): gen_gemm_sm100_module_cutlass_bf16 (193-237)
- flashinfer/utils.py (2): supported_compute_capability (773-853), _get_cache_buf (205-211)
- flashinfer/autotuner.py (7): TunableRunner (194-247), OptimizationProfile (168-183), AutoTuner (335-784), TuningConfig (101-141), DynamicTensorSpec (41-82), ConstraintSpec (86-97), choose_one (400-529)
- csrc/bf16_gemm_cutlass.cu (4): bf16_gemm_tactic_num (149-156), bf16_gemm_tactic_num (149-149), bf16_gemm (144-147), bf16_gemm (144-145)
- flashinfer/fused_moe/utils.py (2): get_last_power_of_2_num_tokens_buckets (206-215), last_positive_power_of_2 (183-188)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
🪛 Ruff (0.14.4)
flashinfer/gemm/gemm_base.py
218-218: Avoid specifying long messages outside the exception class
(TRY003)
220-220: Avoid specifying long messages outside the exception class
(TRY003)
230-232: Avoid specifying long messages outside the exception class
(TRY003)
234-236: Avoid specifying long messages outside the exception class
(TRY003)
238-240: Avoid specifying long messages outside the exception class
(TRY003)
284-284: Avoid specifying long messages outside the exception class
(TRY003)
286-286: Avoid specifying long messages outside the exception class
(TRY003)
297-299: Avoid specifying long messages outside the exception class
(TRY003)
301-303: Avoid specifying long messages outside the exception class
(TRY003)
305-307: Avoid specifying long messages outside the exception class
(TRY003)
500-500: Unused method argument: inputs
(ARG002)
501-501: Unused method argument: profile
(ARG002)
509-509: Unused method argument: do_preparation
(ARG002)
510-510: Unused method argument: kwargs
(ARG002)
d2c8547 to 8a58e45 (Compare)
Actionable comments posted: 0
♻️ Duplicate comments (7)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
156-170: Allow workspace probes to succeed when no buffer is provided.
CutlassBf16GemmRunner::getWorkspaceSizeImpl invokes this launcher with workspacePtr == nullptr and workspaceBytes == 0 to query the required size. The current code throws before returning the computed workspace_size, breaking workspace queries. Short-circuit when workspacePtr is nullptr to return the size without running the kernel.

  size_t workspace_size = gemm.get_workspace_size(arguments);
+ if (workspacePtr == nullptr) {
+   return workspace_size;
+ }
  if (workspace_size > workspaceBytes) {
    throw std::runtime_error("[Bf16 Gemm Runner] insufficient workspace");
  }

include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
166-173: Hash function prone to collisions. The MNKHash function uses XOR to combine hash values (h1 ^ h2 ^ h3), which produces collisions for permutations of the same values. For example, (1, 2, 3) and (3, 2, 1) hash identically, potentially returning incorrect workspace sizes. Use a proper hash combining algorithm:

  struct MNKHash {
    size_t operator()(const MNK& mnk) const {
      auto h1 = std::hash<int>{}(std::get<0>(mnk));
      auto h2 = std::hash<int>{}(std::get<1>(mnk));
      auto h3 = std::hash<int>{}(std::get<2>(mnk));
-     return h1 ^ h2 ^ h3;
+     // Combine hashes properly to avoid collisions
+     size_t seed = h1;
+     seed ^= h2 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+     seed ^= h3 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+     return seed;
    }
  };
175-184: Critical: Data race on static workspace cache. The static workspace_hashmap at Line 175 is accessed concurrently without synchronization. While C++11+ guarantees thread-safe initialization of function-local statics, concurrent access via find() (Line 178) and operator[] (Lines 180, 182) creates data races if getWorkspaceSize is called from multiple threads. Protect the map with a mutex:

+ static std::mutex workspace_mutex;
  static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
  size_t workspace_size = 0;
+ std::lock_guard<std::mutex> lock(workspace_mutex);
  if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {

Alternatively, use std::shared_mutex with shared (read) and exclusive (write) locking for better concurrent read performance.
tests/gemm/test_bmm_bf16.py (1)
15-22: Guard test behind CUDA availability.Calling
get_compute_capability(torch.device("cuda"))without first checkingtorch.cuda.is_available()will raise an exception on non-CUDA systems instead of skipping gracefully.Add an early CUDA check:
def test_bmm_bf16(b, m, n, k, res_dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") compute_capability = get_compute_capability(torch.device(device="cuda"))flashinfer/gemm/gemm_base.py (3)
182-245: Add input validation for dtype, shape, and device consistency. The function is missing essential input validation that could lead to cryptic errors downstream. Per previous review feedback, please add checks at the beginning of the function:

+ if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+     raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.")
+ if a.shape[1] != b.shape[0]:
+     raise ValueError(
+         f"Shape mismatch for matrix multiplication. "
+         f"a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}."
+     )
+ if a.device != b.device:
+     raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.")
+
  if backend != "cutlass":
248-312: Add input validation for dtype, shape, and device consistency. Similar to mm_bf16, this function lacks essential input validation. Per previous review feedback, please add checks at the beginning:

+ if A.dtype != torch.bfloat16 or B.dtype != torch.bfloat16:
+     raise ValueError(f"Inputs must be bfloat16. Got A.dtype={A.dtype}, B.dtype={B.dtype}.")
+ if A.ndim != 3 or B.ndim != 3:
+     raise ValueError(f"Expected 3D tensors. Got A.ndim={A.ndim}, B.ndim={B.ndim}.")
+ if A.shape[0] != B.shape[0]:
+     raise ValueError(
+         f"Batch size mismatch. A.shape[0]={A.shape[0]} != B.shape[0]={B.shape[0]}."
+     )
+ if A.shape[2] != B.shape[1]:
+     raise ValueError(
+         f"Shape mismatch for batched matrix multiplication. "
+         f"A.shape[2]={A.shape[2]} must equal B.shape[1]={B.shape[1]}."
+     )
+ if A.device != B.device:
+     raise ValueError(f"Device mismatch. A.device={A.device}, B.device={B.device}.")
+
  if backend != "cutlass":
511-519: Make the B operand contiguous before invoking the CUTLASS runner.
transpose(-2, -1) returns a non-contiguous view, which causes the runtime error you reported: "RuntimeError: Check failed: (mat2.IsContiguous()) is false: mat2 must be contiguous". Per previous review feedback, materialize the column-major buffer before launching the kernel:

+ b_col_major = b.transpose(-2, -1).contiguous()
  module.bf16_gemm(
      a,
-     b.transpose(-2, -1),
+     b_col_major,
      out,
      workspace_buffer,
      tactic,
  )
🧹 Nitpick comments (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
187-211: Config enumeration is clean but limited to one configuration. The getConfigs implementation only enumerates CtaShape64x64x128B and ClusterShape_1x1x1, reflecting the WIP status and SMEM constraints. The nested loop pattern is extensible for adding more configs once SMEM issues are resolved. Would you like help generating a script to analyze SMEM usage across different tile configurations to understand which sizes are viable for SM100?
flashinfer/jit/gemm/core.py (1)
193-237: WIP tile configurations are appropriate for initial testing. The implementation correctly follows the established pattern from FP8/FP4 modules. The single active tile configuration (64, 64, 128) is a reasonable conservative choice while debugging SMEM constraints on SM100 hardware, especially given your limited B200 access.
Optional style improvement (flagged by static analysis):
- return gen_jit_spec(
-     "bf16_gemm_cutlass",
-     source_paths,
-     extra_cuda_cflags=nvcc_flags + ["-DENABLE_BF16"],
-     extra_cflags=[
-         "-DFAST_BUILD",
-     ],
- )
+ return gen_jit_spec(
+     "bf16_gemm_cutlass",
+     source_paths,
+     extra_cuda_cflags=[*nvcc_flags, "-DENABLE_BF16"],
+     extra_cflags=[
+         "-DFAST_BUILD",
+     ],
+ )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
- csrc/bf16_gemm_cutlass.cu (1 hunks)
- csrc/bf16_gemm_cutlass.jinja (1 hunks)
- docs/api/gemm.rst (1 hunks)
- flashinfer/__init__.py (1 hunks)
- flashinfer/gemm/__init__.py (2 hunks)
- flashinfer/gemm/gemm_base.py (4 hunks)
- flashinfer/jit/gemm/__init__.py (2 hunks)
- flashinfer/jit/gemm/core.py (1 hunks)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h (0 hunks)
- tests/gemm/test_bmm_bf16.py (1 hunks)
- tests/gemm/test_mm_bf16.py (1 hunks)
💤 Files with no reviewable changes (1)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🚧 Files skipped from review as they are similar to previous changes (2)
- csrc/bf16_gemm_cutlass.cu
- tests/gemm/test_mm_bf16.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
- csrc/bf16_gemm_cutlass.jinja
- include/flashinfer/gemm/bf16_gemm_template_sm100.h
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h
- flashinfer/gemm/gemm_base.py
- include/flashinfer/gemm/bf16_gemm_cutlass.h
- flashinfer/__init__.py
🧬 Code graph analysis (9)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2): flashinfer (41-134), gemm (42-91)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (2): flashinfer (26-60), gemm (27-59)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (2): flashinfer (26-60), gemm (27-59)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (6): gemm (44-176), _1SM (53-57), cutlass (135-135), cutlass (136-136), cutlass (137-137), cutlass (138-138)
- include/flashinfer/gemm/cutlass_gemm_configs.h (1): CutlassTileConfigSM100 (106-425)
flashinfer/gemm/gemm_base.py (8)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1): gemm (42-91)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (1): gemm (27-59)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (1): gemm (44-176)
- flashinfer/jit/gemm/core.py (1): gen_gemm_sm100_module_cutlass_bf16 (193-237)
- flashinfer/utils.py (2): supported_compute_capability (773-853), _get_cache_buf (205-211)
- flashinfer/autotuner.py (7): TunableRunner (194-247), OptimizationProfile (168-183), AutoTuner (335-786), TuningConfig (101-141), DynamicTensorSpec (41-82), ConstraintSpec (86-97), choose_one (400-529)
- csrc/bf16_gemm_cutlass.cu (4): bf16_gemm_tactic_num (149-156), bf16_gemm_tactic_num (149-149), bf16_gemm (144-147), bf16_gemm (144-145)
- flashinfer/fused_moe/utils.py (2): get_last_power_of_2_num_tokens_buckets (206-215), last_positive_power_of_2 (183-188)
include/flashinfer/gemm/bf16_gemm_cutlass.h (3)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2): flashinfer (41-134), gemm (42-91)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (1): gemm (44-176)
- flashinfer/gemm/gemm_base.py (1): CutlassBf16GemmRunner (496-519)
flashinfer/gemm/__init__.py (1)
- flashinfer/gemm/gemm_base.py (2): bmm_bf16 (249-312), mm_bf16 (183-245)
flashinfer/__init__.py (1)
- flashinfer/gemm/gemm_base.py (2): bmm_bf16 (249-312), mm_bf16 (183-245)
flashinfer/jit/gemm/__init__.py (1)
- flashinfer/jit/gemm/core.py (1): gen_gemm_sm100_module_cutlass_bf16 (193-237)
tests/gemm/test_bmm_bf16.py (2)
- flashinfer/gemm/gemm_base.py (1): bmm_bf16 (249-312)
- flashinfer/utils.py (2): get_compute_capability (252-255), is_compute_capability_supported (979-994)
flashinfer/jit/gemm/core.py (2)
- flashinfer/jit/core.py (2): JitSpec (213-312), gen_jit_spec (315-381)
- flashinfer/compilation_context.py (1): get_nvcc_flags_list (50-68)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_template_sm100.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_cutlass.h
[error] 20-20: 'cuda_runtime_api.h' file not found
(clang-diagnostic-error)
🪛 Ruff (0.14.4)
flashinfer/gemm/gemm_base.py
217-217: Avoid specifying long messages outside the exception class
(TRY003)
219-219: Avoid specifying long messages outside the exception class
(TRY003)
229-231: Avoid specifying long messages outside the exception class
(TRY003)
233-235: Avoid specifying long messages outside the exception class
(TRY003)
237-239: Avoid specifying long messages outside the exception class
(TRY003)
283-283: Avoid specifying long messages outside the exception class
(TRY003)
285-285: Avoid specifying long messages outside the exception class
(TRY003)
296-298: Avoid specifying long messages outside the exception class
(TRY003)
300-302: Avoid specifying long messages outside the exception class
(TRY003)
304-306: Avoid specifying long messages outside the exception class
(TRY003)
499-499: Unused method argument: inputs
(ARG002)
500-500: Unused method argument: profile
(ARG002)
508-508: Unused method argument: do_preparation
(ARG002)
509-509: Unused method argument: kwargs
(ARG002)
flashinfer/jit/gemm/core.py
233-233: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation
Replace with [*nvcc_flags, "-DENABLE_BF16"]
(RUF005)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (12)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
1-62: LGTM! Clean interface/implementation pattern. The abstract interface and templated concrete class follow best practices for extensibility. The separation of public getWorkspaceSize and private getWorkspaceSizeImpl suggests proper encapsulation of workspace size computation logic.
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
46-151: LGTM! Standard CUTLASS GEMM setup. The SMTypeAdapter specializations and launcher configuration follow CUTLASS patterns correctly. Regarding the comment on Line 147: setting fusion_args.alpha = 1.0f and fusion_args.beta = 0.0f is the standard way to configure a GEMM epilogue for D = A*B (no accumulation). This is the right approach.
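For reference, the standard linear-combination epilogue computes the following, so alpha = 1 and beta = 0 reduce it to a plain product:

```latex
D = \alpha \,(A B) + \beta\, C,
\qquad \alpha = 1,\ \beta = 0 \;\Longrightarrow\; D = A B .
```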
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)
136-143: LGTM! Clean forwarding to dispatcher.
145-160: Exception handling is appropriate for config probing. The pattern of catching and ignoring std::runtime_error when probing workspace sizes is acceptable, as some configurations may legitimately fail due to SMEM constraints. The comment on Line 155 documents the rationale clearly. Based on learnings.
44-134: No changes needed; the review comment is accurate. Verification confirms the bf16 dispatcher is intentionally limited to CtaShape64x64x128B and ClusterShape_1x1x1 (lines 100-103 and 65-67), while other tile configs and cluster shapes remain commented out. This differs from the fp8 implementation, which enables multiple configurations, confirming the bf16 limitation is deliberate due to SMEM constraints as noted. The transpose pattern (swapping B, A and n, m at lines 101-102) is correct for layout handling.
docs/api/gemm.rst (1)
10-18: LGTM! Documentation follows existing patterns. The new BF16 GEMM section properly documents the mm_bf16 and bmm_bf16 entry points, following the same autosummary format as other GEMM types in this file.
flashinfer/__init__.py (1)
88-90: LGTM! BF16 GEMM exports are now available. The imports of bmm_bf16 and mm_bf16 from the gemm module expose the new BF16 GEMM functionality at the top level. Past review comments indicate the necessary exports were added to flashinfer/gemm/__init__.py.
flashinfer/jit/gemm/__init__.py (1)
22-22: LGTM! JIT generator export follows existing patterns. The gen_gemm_sm100_module_cutlass_bf16 import and export are consistent with other GEMM generators in this module.
Also applies to: 37-37
flashinfer/gemm/__init__.py (1)
2-2: LGTM! GEMM module exports are properly configured. The bmm_bf16 and mm_bf16 imports from gemm_base and their inclusion in __all__ enable the top-level imports in flashinfer/__init__.py to work correctly.
Also applies to: 4-4, 25-25, 27-27
tests/gemm/test_bmm_bf16.py (1)
23-34: LGTM! Test logic is sound. The test correctly creates BF16 inputs, computes a reference with torch.bmm, and validates the bmm_bf16 output using cosine similarity. The threshold of 0.99 is reasonable for BF16 precision.
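A sketch of that style of check (hypothetical helper mirroring the described test logic, not the test's exact code):

```python
import torch

def assert_close_cosine(actual: torch.Tensor, reference: torch.Tensor, threshold: float = 0.99) -> None:
    # Compare the flattened results by direction; bf16 accumulation error shifts
    # magnitudes slightly, so cosine similarity is a tolerant pass/fail metric.
    cos = torch.nn.functional.cosine_similarity(
        actual.float().flatten(), reference.float().flatten(), dim=0
    )
    assert cos.item() > threshold, f"cosine similarity {cos.item():.4f} <= {threshold}"

# Hypothetical usage against a torch.bmm reference:
# reference = torch.bmm(A.float(), B.float()).to(out.dtype)
# assert_close_cosine(out, reference)
```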
csrc/bf16_gemm_cutlass.jinja (1)
1-27: LGTM! Clean template structure with conservative defaults. The template correctly instantiates the SM100 BF16 GEMM kernel with a single-SM cluster configuration (1,1,1), which is appropriate for initial testing. The commented-out multi-SM cluster configurations provide clear guidance for future performance tuning once the basic implementation is validated.
flashinfer/gemm/gemm_base.py (1)
577-616: LGTM! AutoTuner integration follows established patterns. The function correctly uses a.device for SM version checking and properly integrates with the AutoTuner for dynamic tactic selection. The tuning configuration appropriately profiles on the M dimension using power-of-2 bucketing, matching the pattern used in fp8_gemm_sm100.
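A rough sketch of what power-of-2 bucketing on M means for tactic reuse (illustrative reimplementation; the real helpers are get_last_power_of_2_num_tokens_buckets and last_positive_power_of_2 in flashinfer/fused_moe/utils.py):

```python
def floor_power_of_2(x: int) -> int:
    # Round x down to the nearest power of two (illustrative only).
    p = 1
    while p * 2 <= x:
        p *= 2
    return p

# Shapes whose M lands in the same bucket reuse one tuned tactic:
for m in (48, 64, 100, 128, 1000):
    print(m, "->", floor_power_of_2(m))
# 48 -> 32, 64 -> 64, 100 -> 64, 128 -> 128, 1000 -> 512
```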
dcbc17a to 28baee5 (Compare)
Actionable comments posted: 0
♻️ Duplicate comments (7)
tests/gemm/test_mm_bf16.py (2)
25-31: Use a row-major (k, n) weight and avoid .T to keep inputs contiguous. Generate mat2 as (k, n) and use it directly in both the reference and the API call. This prevents passing a non-contiguous transpose and matches the documented contract.

- mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
-
- reference = torch.mm(input, mat2.T)
+ mat2 = torch.randn([k, n], device="cuda", dtype=torch.bfloat16)
+ reference = torch.mm(input, mat2)
  ...
- mm_bf16(input, mat2.T, out=out, out_dtype=res_dtype)
+ mm_bf16(input, mat2, out=out, out_dtype=res_dtype)
14-16: Skip on CPU-only to avoid hard failure. Add a CUDA-availability guard before calling get_compute_capability.

  def test_mm_bf16(m: int, n: int, k: int, res_dtype: torch.dtype):
-     compute_capability = get_compute_capability(torch.device(device="cuda"))
+     if not torch.cuda.is_available():
+         pytest.skip("CUDA is not available")
+     compute_capability = get_compute_capability(torch.device(device="cuda"))

tests/gemm/test_bmm_bf16.py (1)
14-16: Skip on CPU-only to avoid hard failure. Guard get_compute_capability(torch.device("cuda")) with a CUDA-availability check so the test skips instead of crashing on CPU-only runners.

  def test_bmm_bf16(b, m, n, k, res_dtype):
-     compute_capability = get_compute_capability(torch.device(device="cuda"))
+     if not torch.cuda.is_available():
+         pytest.skip("CUDA is not available")
+     compute_capability = get_compute_capability(torch.device(device="cuda"))

flashinfer/gemm/gemm_base.py (3)
182-205: Add essential input validation (dtype, shape, device). Fail fast with clear errors to avoid cryptic backend failures.

  def mm_bf16(
      a: torch.Tensor,
      b: torch.Tensor,
@@
  ) -> torch.Tensor:
@@
- if backend != "cutlass":
+ # Basic validations
+ if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+     raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.")
+ if a.ndim != 2 or b.ndim != 2:
+     raise ValueError(f"mm_bf16 expects 2D tensors. Got a.ndim={a.ndim}, b.ndim={b.ndim}.")
+ if a.shape[1] != b.shape[0]:
+     raise ValueError(
+         f"Shape mismatch: a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}."
+     )
+ if a.device != b.device:
+     raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.")
+
+ if backend != "cutlass":
      raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")
259-283: Add essential input validation (batched dtype/shape/device). Validate 3D inputs, batch, and K dims before launching the kernel.

  def bmm_bf16(
      A: torch.Tensor,
      B: torch.Tensor,
@@
  ) -> torch.Tensor:
@@
- if backend != "cutlass":
+ # Basic validations
+ if A.dtype != torch.bfloat16 or B.dtype != torch.bfloat16:
+     raise ValueError(f"Inputs must be bfloat16. Got A.dtype={A.dtype}, B.dtype={B.dtype}.")
+ if A.ndim != 3 or B.ndim != 3:
+     raise ValueError(f"bmm_bf16 expects 3D tensors. Got A.ndim={A.ndim}, B.ndim={B.ndim}.")
+ if A.shape[0] != B.shape[0]:
+     raise ValueError(f"Batch size mismatch: A.shape[0]={A.shape[0]} != B.shape[0]={B.shape[0]}.")
+ if A.shape[2] != B.shape[1]:
+     raise ValueError(
+         f"K mismatch: A.shape[2]={A.shape[2]} must equal B.shape[1]={B.shape[1]}."
+     )
+ if A.device != B.device:
+     raise ValueError(f"Device mismatch. A.device={A.device}, B.device={B.device}.")
+
+ if backend != "cutlass":
      raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")
533-541: Fix runtime error: make B contiguous before calling the CUTLASS binding.
b.transpose(-2, -1) is a non-contiguous view; the C++ binding asserts contiguity ("mat2 must be contiguous"). Materialize column-major B.

- module.bf16_gemm(
-     a,
-     b.transpose(-2, -1),
-     out,
-     workspace_buffer,
-     tactic,
- )
+ b_col_major = b.transpose(-2, -1).contiguous()
+ module.bf16_gemm(
+     a,
+     b_col_major,
+     out,
+     workspace_buffer,
+     tactic,
+ )

include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
157-175: Thread-safety and hash quality in workspace cache.
- XOR-combining hashes collides easily.
- The static workspace_hashmap is accessed unsafely across threads.
@@ - struct MNKHash { + struct MNKHash { size_t operator()(const MNK& mnk) const { auto h1 = std::hash<int>{}(std::get<0>(mnk)); auto h2 = std::hash<int>{}(std::get<1>(mnk)); auto h3 = std::hash<int>{}(std::get<2>(mnk)); - return h1 ^ h2 ^ h3; + // Robust hash combine to reduce collisions + size_t seed = h1; + seed ^= h2 + 0x9e3779b9 + (seed << 6) + (seed >> 2); + seed ^= h3 + 0x9e3779b9 + (seed << 6) + (seed >> 2); + return seed; } }; @@ - static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap; + static std::mutex workspace_mutex; + static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap; @@ - size_t workspace_size = 0; - if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) { - workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k); - workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size; - } else { - workspace_size = workspace_hashmap[std::make_tuple(m, n, k)]; - } - return workspace_size; + const MNK key = std::make_tuple(m, n, k); + { + std::lock_guard<std::mutex> lock(workspace_mutex); + auto it = workspace_hashmap.find(key); + if (it != workspace_hashmap.end()) { + return it->second; + } + } + // Compute outside lock to avoid blocking others; insert with lock. + size_t computed = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k); + { + std::lock_guard<std::mutex> lock(workspace_mutex); + auto it = workspace_hashmap.find(key); + if (it == workspace_hashmap.end()) { + workspace_hashmap.emplace(key, computed); + return computed; + } + return it->second; + }Also add the include near the top:
-#include <stdexcept> +#include <stdexcept> +#include <mutex>
🧹 Nitpick comments (2)
flashinfer/jit/gemm/core.py (1)
229-236: Minor: prefer list splat over concatenation (RUF005).Use list unpacking for readability in
extra_cuda_cflags.- return gen_jit_spec( + return gen_jit_spec( "bf16_gemm_cutlass", source_paths, - extra_cuda_cflags=nvcc_flags + ["-DENABLE_BF16"], + extra_cuda_cflags=[*nvcc_flags, "-DENABLE_BF16"], extra_cflags=[ "-DFAST_BUILD", ], )include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
101-104: Remove unused template parameter or use it.
genericBf16GemmKernelLauncherSm100 has template param arch but hardcodes ArchTag = cutlass::arch::Sm100. Either use arch or drop the parameter.

- using ArchTag = cutlass::arch::Sm100;
+ using ArchTag = arch;
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
- csrc/bf16_gemm_cutlass.cu (1 hunks)
- csrc/bf16_gemm_cutlass.jinja (1 hunks)
- docs/api/gemm.rst (1 hunks)
- flashinfer/__init__.py (1 hunks)
- flashinfer/gemm/__init__.py (2 hunks)
- flashinfer/gemm/gemm_base.py (4 hunks)
- flashinfer/jit/gemm/__init__.py (2 hunks)
- flashinfer/jit/gemm/core.py (1 hunks)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h (0 hunks)
- tests/gemm/test_bmm_bf16.py (1 hunks)
- tests/gemm/test_mm_bf16.py (1 hunks)
💤 Files with no reviewable changes (1)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🚧 Files skipped from review as they are similar to previous changes (5)
- docs/api/gemm.rst
- csrc/bf16_gemm_cutlass.cu
- flashinfer/__init__.py
- flashinfer/gemm/__init__.py
- flashinfer/jit/gemm/__init__.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
csrc/bf16_gemm_cutlass.jinjainclude/flashinfer/gemm/bf16_gemm_template_sm100.hinclude/flashinfer/gemm/bf16_gemm_cutlass.hflashinfer/gemm/gemm_base.pyinclude/flashinfer/gemm/bf16_gemm_cutlass_template.h
🧬 Code graph analysis (7)
tests/gemm/test_mm_bf16.py (3)
flashinfer/autotuner.py (1)
autotune(251-262)flashinfer/gemm/gemm_base.py (1)
mm_bf16(183-256)flashinfer/utils.py (2)
get_compute_capability(252-255)is_compute_capability_supported(979-994)
tests/gemm/test_bmm_bf16.py (3)
flashinfer/autotuner.py (1)
autotune(251-262)flashinfer/gemm/gemm_base.py (1)
bmm_bf16(260-334)flashinfer/utils.py (2)
get_compute_capability(252-255)is_compute_capability_supported(979-994)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
JitSpec(213-312)gen_jit_spec(315-381)flashinfer/compilation_context.py (1)
get_nvcc_flags_list(50-68)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-125)gemm(42-91)include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
flashinfer(26-60)gemm(27-59)
include/flashinfer/gemm/bf16_gemm_cutlass.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-125)gemm(42-91)include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
gemm(44-182)flashinfer/gemm/gemm_base.py (1)
CutlassBf16GemmRunner(518-541)
flashinfer/gemm/gemm_base.py (8)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
gemm(42-91)include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
gemm(44-182)include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
gemm(27-59)flashinfer/jit/gemm/core.py (1)
gen_gemm_sm100_module_cutlass_bf16(193-236)flashinfer/utils.py (2)
supported_compute_capability(773-853)_get_cache_buf(205-211)flashinfer/autotuner.py (5)
TunableRunner(194-247)OptimizationProfile(168-183)TuningConfig(101-141)DynamicTensorSpec(41-82)ConstraintSpec(86-97)csrc/bf16_gemm_cutlass.cu (4)
bf16_gemm_tactic_num(149-156)bf16_gemm_tactic_num(149-149)bf16_gemm(144-147)bf16_gemm(144-145)flashinfer/fused_moe/utils.py (2)
get_last_power_of_2_num_tokens_buckets(206-215)last_positive_power_of_2(183-188)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (7)
gemm(44-182)_1SM(53-57)_2SM(60-64)cutlass(135-135)cutlass(136-136)cutlass(137-137)cutlass(138-138)include/flashinfer/gemm/cutlass_gemm_configs.h (1)
CutlassTileConfigSM100(106-425)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_template_sm100.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_cutlass.h
[error] 20-20: 'cuda_runtime_api.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
🪛 Ruff (0.14.4)
flashinfer/jit/gemm/core.py
232-232: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation
Replace with [*nvcc_flags, "-DENABLE_BF16"]
(RUF005)
flashinfer/gemm/gemm_base.py
228-228: Avoid specifying long messages outside the exception class
(TRY003)
230-230: Avoid specifying long messages outside the exception class
(TRY003)
240-242: Avoid specifying long messages outside the exception class
(TRY003)
244-246: Avoid specifying long messages outside the exception class
(TRY003)
248-250: Avoid specifying long messages outside the exception class
(TRY003)
305-305: Avoid specifying long messages outside the exception class
(TRY003)
307-307: Avoid specifying long messages outside the exception class
(TRY003)
318-320: Avoid specifying long messages outside the exception class
(TRY003)
322-324: Avoid specifying long messages outside the exception class
(TRY003)
326-328: Avoid specifying long messages outside the exception class
(TRY003)
521-521: Unused method argument: inputs
(ARG002)
522-522: Unused method argument: profile
(ARG002)
530-530: Unused method argument: do_preparation
(ARG002)
531-531: Unused method argument: kwargs
(ARG002)
🔇 Additional comments (2)
csrc/bf16_gemm_cutlass.jinja (1)
17-26: Instantiation set looks good. Coverage of cluster shapes for 1SM/2SM variants matches the SM100 launcher; no issues spotted.
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
152-157: Good: explicit workspace probe path. Early-return on null A/B/D ensures `getWorkspaceSizeImpl` can probe without needing a buffer. This unblocks tactic sizing. Based on learnings.
Hi experts, I think this is now ready for review! Right now we are passing all the tests that I wrote for this feature:

Test Results (click to expand)

The original issue (#1974) was to see if a CUTLASS-backend GEMM for BF16 could do better at smaller batch sizes. Now, using it:

Benchmark Results (click to expand)

The highlight is that at a larger batch size like batch=64 we are at ~7 microseconds, while TGV is at ~12.5-13 microseconds and the original/torch.compile() path is at ~8 microseconds. I'm a CUTLASS newbie, so maybe adding more tile sizes and cluster shapes / autotuning wider can get even better performance, since we are slightly worse elsewhere...

For reviewers:
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
include/flashinfer/gemm/bf16_gemm_template_sm100.h(1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/gemm/bf16_gemm_template_sm100.h
🧬 Code graph analysis (1)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-125)gemm(42-91)include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
flashinfer(26-60)gemm(27-59)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_template_sm100.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
183-190: LGTM! Macro enables flexible template instantiation. The `INSTANCE_BF16_GEMM_TEMPLATE_SM100` macro correctly provides explicit template instantiation control for different tile sizes, cluster shapes, and SM types. The parameter list matches the launcher's template signature, and the macro will be used by the JIT generator to instantiate specific configurations.
49-50: The forward declarations in this file are not problematic duplicates requiring consolidation. The actual struct definitions of `_1SM` and `_2SM` are in `include/flashinfer/gemm/bf16_gemm_cutlass_template.h` (lines 44-45), while the SM100 template files provide independent forward declarations. This is the correct C++ pattern: the base template defines the types, and SM100 template files forward-declare them to specialize `SMTypeAdapter<_1SM>` and `SMTypeAdapter<_2SM>` without incurring unnecessary includes. This separation of concerns is appropriate and consistent across all GEMM implementations (bf16, fp8, fp4).
Likely an incorrect or invalid review comment.
```cpp
      throw std::runtime_error("[Bf16 Gemm Runner] insufficient workspace");
    }

    auto can_implement = gemm.can_implement(arguments);
```
Is there any advantage to doing these safety checks this way instead of just using the `CUTLASS_CHECK` macro? I saw it done this way for FP8 and FP4, so I kept it that way, but I'm just wondering, because it seems the same?
1387bed to a56d74b Compare
Actionable comments posted: 0
♻️ Duplicate comments (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
166-176: Critical: Protect static workspace cache from concurrent access. The static `workspace_hashmap` at line 166 is accessed concurrently without synchronization. While C++11+ guarantees thread-safe static initialization, subsequent operations (`find()` at line 169, `operator[]` at lines 171 and 173) create data races when `getWorkspaceSize` is called from multiple threads, potentially causing:
- Cache corruption
- Iterator invalidation
- Undefined behavior
The BF16 GEMM APIs can be called concurrently from multiple threads in typical inference workloads, making this a critical issue.
🔎 Add mutex protection for thread safety
```diff
+  static std::mutex workspace_mutex;
   static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
   size_t workspace_size = 0;
+  std::lock_guard<std::mutex> lock(workspace_mutex);
   if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {
     workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
     workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size;
   } else {
     workspace_size = workspace_hashmap[std::make_tuple(m, n, k)];
   }
   return workspace_size;
```
Alternatively, use `std::shared_mutex` with `std::shared_lock` for reads and `std::unique_lock` for writes to allow concurrent readers while still protecting writes.
flashinfer/gemm/gemm_base.py (1)
818-826: Correct the type annotation for the `bias` parameter. The `bias` parameter can be `None` (e.g., `bmm_bf16` calls this with `None` at line 509), but the type annotation says `torch.Tensor`. It should be `Optional[torch.Tensor]` for correctness.
Proposed fix
```diff
 def bf16_gemm_sm100(
     a: torch.Tensor,
     b: torch.Tensor,
-    bias: torch.Tensor,
+    bias: Optional[torch.Tensor],
     pdl: bool,
     out: torch.Tensor,
     workspace_buffer: torch.Tensor,
     runner_names: List[str],
 ) -> None:
```
🧹 Nitpick comments (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
157-164: Consider a collision-resistant hash combiner. The `MNKHash` function uses XOR to combine three hash values (`h1 ^ h2 ^ h3`), which can produce identical hashes for different permutations of the same values. For example, `(m=1, n=2, k=3)` and `(m=3, n=2, k=1)` would hash to the same value, potentially causing the cache to return incorrect workspace sizes. While collisions here would only cause redundant recomputation (not correctness issues), a more robust hash combiner would improve cache efficiency.
🔎 Collision-resistant hash combiner
```diff
 struct MNKHash {
   size_t operator()(const MNK& mnk) const {
     auto h1 = std::hash<int>{}(std::get<0>(mnk));
     auto h2 = std::hash<int>{}(std::get<1>(mnk));
     auto h3 = std::hash<int>{}(std::get<2>(mnk));
-    return h1 ^ h2 ^ h3;
+    // Combine hashes to avoid collisions from permutations
+    size_t seed = h1;
+    seed ^= h2 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+    seed ^= h3 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+    return seed;
   }
 };
```
219-222: Clarify the TGV output dtype restriction error message.The current error message is confusing. When
out_dtype != torch.bfloat16, it says "You cannot provide an output dtype" which is misleading—the real constraint is that TGV only supportsbfloat16output. Consider rewording for clarity:Proposed fix
if out_dtype != torch.bfloat16: raise ValueError( - "You cannot provide an output dtype to the TGV backend. Use the CUTLASS backend instead." + "TGV backend only supports bfloat16 output dtype. Use the CUTLASS backend for fp16 output." )
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
csrc/bf16_gemm_cutlass.cucsrc/bf16_gemm_cutlass.jinjadocs/api/gemm.rstflashinfer/__init__.pyflashinfer/gemm/__init__.pyflashinfer/gemm/gemm_base.pyflashinfer/jit/gemm/__init__.pyflashinfer/jit/gemm/core.pyinclude/flashinfer/gemm/bf16_gemm_cutlass.hinclude/flashinfer/gemm/bf16_gemm_cutlass_template.hinclude/flashinfer/gemm/bf16_gemm_template_sm100.htests/gemm/test_bmm_bf16.pytests/gemm/test_mm_bf16.py
🚧 Files skipped from review as they are similar to previous changes (5)
- tests/gemm/test_bmm_bf16.py
- csrc/bf16_gemm_cutlass.jinja
- include/flashinfer/gemm/bf16_gemm_cutlass.h
- flashinfer/gemm/init.py
- tests/gemm/test_mm_bf16.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/gemm/bf16_gemm_template_sm100.hcsrc/bf16_gemm_cutlass.cuinclude/flashinfer/gemm/bf16_gemm_cutlass_template.hflashinfer/__init__.pyflashinfer/gemm/gemm_base.py
🧬 Code graph analysis (7)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-125)gemm(42-91)
flashinfer/jit/gemm/__init__.py (1)
flashinfer/jit/gemm/core.py (1)
gen_gemm_sm100_module_cutlass_bf16(193-236)
csrc/bf16_gemm_cutlass.cu (3)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
CutlassBf16GemmRunnerInterface(29-41)include/flashinfer/gemm/cutlass_gemm_configs.h (1)
CutlassTileConfigSM100(106-425)csrc/tvm_ffi_utils.h (2)
get_stream(294-296)encode_dlpack_dtype(30-32)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
flashinfer(26-60)gemm(27-59)include/flashinfer/gemm/bf16_gemm_template_sm100.h (5)
gemm(44-181)cutlass(135-135)cutlass(136-136)cutlass(137-137)cutlass(138-138)include/flashinfer/gemm/cutlass_gemm_configs.h (1)
CutlassTileConfigSM100(106-425)
flashinfer/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
bmm_bf16(441-510)mm_bf16(285-384)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
JitSpec(216-397)gen_jit_spec(400-466)flashinfer/compilation_context.py (1)
get_nvcc_flags_list(50-68)
flashinfer/gemm/gemm_base.py (1)
flashinfer/utils.py (4)
supported_compute_capability(819-899)backend_requirement(902-1184)_get_cache_buf(206-217)suitable_auto_backends(1076-1096)
🪛 Ruff (0.14.10)
flashinfer/jit/gemm/core.py
232-232: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation
Replace with [*nvcc_flags, "-DENABLE_BF16"]
(RUF005)
flashinfer/gemm/gemm_base.py
187-187: Unused function argument: a
(ARG001)
188-188: Unused function argument: b
(ARG001)
189-189: Unused function argument: out
(ARG001)
193-193: Unused function argument: backend
(ARG001)
196-198: Avoid specifying long messages outside the exception class
(TRY003)
200-202: Avoid specifying long messages outside the exception class
(TRY003)
211-211: Unused function argument: a
(ARG001)
212-212: Unused function argument: b
(ARG001)
213-213: Unused function argument: out
(ARG001)
215-215: Unused function argument: bias
(ARG001)
216-216: Unused function argument: pdl
(ARG001)
217-217: Unused function argument: backend
(ARG001)
220-222: Avoid specifying long messages outside the exception class
(TRY003)
230-230: Unused function argument: pdl
(ARG001)
231-231: Unused function argument: out
(ARG001)
232-232: Unused function argument: out_dtype
(ARG001)
233-233: Unused function argument: backend
(ARG001)
236-238: Avoid specifying long messages outside the exception class
(TRY003)
240-242: Avoid specifying long messages outside the exception class
(TRY003)
245-247: Avoid specifying long messages outside the exception class
(TRY003)
255-255: Unused function argument: b
(ARG001)
258-258: Unused function argument: out
(ARG001)
259-259: Unused function argument: out_dtype
(ARG001)
260-260: Unused function argument: backend
(ARG001)
355-357: Avoid specifying long messages outside the exception class
(TRY003)
359-361: Avoid specifying long messages outside the exception class
(TRY003)
363-365: Avoid specifying long messages outside the exception class
(TRY003)
389-389: Unused function argument: A
(ARG001)
390-390: Unused function argument: B
(ARG001)
391-391: Unused function argument: out
(ARG001)
393-393: Unused function argument: backend
(ARG001)
403-403: Unused function argument: out
(ARG001)
404-404: Unused function argument: out_dtype
(ARG001)
405-405: Unused function argument: backend
(ARG001)
408-410: Avoid specifying long messages outside the exception class
(TRY003)
412-414: Avoid specifying long messages outside the exception class
(TRY003)
421-421: Unused function argument: A
(ARG001)
422-422: Unused function argument: B
(ARG001)
423-423: Unused function argument: out
(ARG001)
424-424: Unused function argument: out_dtype
(ARG001)
425-425: Unused function argument: backend
(ARG001)
446-446: Unused function argument: backend
(ARG001)
494-496: Avoid specifying long messages outside the exception class
(TRY003)
498-500: Avoid specifying long messages outside the exception class
(TRY003)
502-504: Avoid specifying long messages outside the exception class
(TRY003)
770-770: Unused method argument: inputs
(ARG002)
771-771: Unused method argument: profile
(ARG002)
779-779: Unused method argument: do_preparation
(ARG002)
780-780: Unused method argument: kwargs
(ARG002)
1068-1070: Avoid specifying long messages outside the exception class
(TRY003)
1072-1074: Avoid specifying long messages outside the exception class
(TRY003)
1076-1078: Avoid specifying long messages outside the exception class
(TRY003)
1655-1658: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (24)
docs/api/gemm.rst (1)
10-18: LGTM! Documentation structure is correct. The BF16 GEMM section is well-positioned and follows the established pattern for documenting GEMM operations in this module.
flashinfer/__init__.py (1)
88-91: LGTM! BF16 GEMM APIs properly exposed at package root. The imports of `bmm_bf16` and `mm_bf16` correctly wire the root package API to the BF16 GEMM functionality, aligning with the implementation in `flashinfer/gemm/gemm_base.py` and its re-exports in `flashinfer/gemm/__init__.py`.
flashinfer/jit/gemm/__init__.py (2)
22-22: LGTM! BF16 Cutlass generator properly imported. The import correctly exposes the new `gen_gemm_sm100_module_cutlass_bf16` function from the core module.
38-38: LGTM! Export added to `__all__`. The function is properly included in the public API surface of the jit.gemm module.
flashinfer/jit/gemm/core.py (1)
193-236: LGTM! BF16 Cutlass JIT generator follows established patterns. The function correctly mirrors the existing FP8/FP4 generation workflow, rendering templates for multiple CTA configurations and returning a properly configured JIT spec.
The static analysis tool suggests using list unpacking (`[*nvcc_flags, "-DENABLE_BF16"]`) instead of concatenation at line 232, but the current approach is consistent with the rest of the codebase and works correctly.
csrc/bf16_gemm_cutlass.cu (5)
40-42: LGTM! Explicit template instantiations are correct. The explicit instantiations for `CutlassBf16GemmRunner<__nv_bfloat16>` and `CutlassBf16GemmRunner<half>` ensure the template implementations are available for linking.
49-58: LGTM! Config selection with proper bounds checking. The `getBf16GemmConfig` function correctly:
- Lazily initializes a static vector of configs
- Validates the tactic index is within bounds
- Provides a clear error message on out-of-bounds access
60-83: LGTM! Workspace management correctly handles both cases. The `runGemm` function properly:
- Computes required workspace size
- Allocates temporary workspace if the provided buffer is insufficient
- Reuses the provided workspace when adequate
- Passes correct parameters to the GEMM runner
85-140: LGTM! Input validation and dispatch logic is sound. The `bf16_bmm_impl` function correctly:
- Validates input dtypes
- Handles both 2D (matrix) and 3D (batched) inputs
- Checks dimension compatibility with clear error messages
- Validates output shape and dtype
- Dispatches to the appropriate template instantiation based on output dtype
Note: The past review mentioned a non-contiguous `mat2` issue. This is correctly handled in the Python layer (`flashinfer/gemm/gemm_base.py`), where the caller ensures contiguity before passing to the FFI.
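As an illustration of that Python-side handling (a hedged sketch; the actual helper in `gemm_base.py` may be shaped differently):

```python
import torch

def _ensure_contiguous(mat: torch.Tensor) -> torch.Tensor:
    # The CUTLASS FFI expects contiguous operands; copy only when necessary
    # so the common already-contiguous case stays zero-cost.
    return mat if mat.is_contiguous() else mat.contiguous()
```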
144-161: LGTM! FFI exports are properly defined. The public `bf16_gemm` and `bf16_gemm_tactic_num` functions are correctly exposed via TVM FFI macros, enabling Python-side access to the BF16 GEMM functionality.
include/flashinfer/gemm/bf16_gemm_template_sm100.h (4)
46-64: LGTM! SM type adapters correctly configured. The `SMTypeAdapter` specializations properly map `_1SM` and `_2SM` to their respective:
- Scale factors (1 and 2)
- Epilogue schedules (TmaWarpSpecialized1Sm and TmaWarpSpecialized2Sm)
- Mainloop schedules (KernelTmaWarpSpecialized1SmSm100 and KernelTmaWarpSpecialized2SmSm100)
66-153: LGTM! Launcher correctly configures Cutlass GEMM. The `genericBf16GemmKernelLauncherSm100` function properly:
- Defines element types, layouts, and alignments for A, B, C, D
- Handles conditional type selection for BF16 support
- Builds CollectiveEpilogue and CollectiveMainloop with appropriate tile and cluster shapes
- Constructs stride descriptors for batched operation
- Configures fusion arguments (alpha/beta)
183-190: LGTM! Macro correctly instantiates BF16 GEMM templates. The `INSTANCE_BF16_GEMM_TEMPLATE_SM100` macro properly expands to explicit template instantiations for the launcher with specified tile and cluster configurations.
154-178: The workspace query mechanism is working as designed. The `if (!A && !B && !D)` pattern at lines 154-155 is the established, intentional design across all GEMM implementations (bf16, fp8, fp4). When probing for workspace size, `getWorkspaceSizeImpl()` explicitly calls `dispatchToArch()` with `nullptr` for all data pointers (A, B, D) along with `nullptr` for `workspacePtr` and `workspaceBytes=0`. The data pointer check correctly detects and handles this probe scenario, returning the required workspace size without triggering subsequent validation. Checking `workspacePtr` instead would not improve robustness, as the pointer is part of the function signature regardless. The pattern's correctness is further validated by the error-handling comments in similar implementations that explicitly document swallowing SMEM constraint errors during configuration probing.
Likely an incorrect or invalid review comment.
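A plain-Python stand-in for the probe convention described above (illustrative only; the real logic lives in the C++ launcher, and `workspace_bytes_needed` here is a placeholder):

```python
def workspace_bytes_needed(m: int, n: int, k: int) -> int:
    # Placeholder for the CUTLASS-derived workspace requirement.
    return m * n * 4

def launch_or_probe(a, b, d, workspace, workspace_bytes, m, n, k) -> int:
    required = workspace_bytes_needed(m, n, k)
    if a is None and b is None and d is None:
        # Workspace-size query: callers pass null data pointers on purpose.
        return required
    if workspace_bytes < required:
        raise RuntimeError("[Bf16 Gemm Runner] insufficient workspace")
    # ... a real implementation would configure and launch the kernel here ...
    return required
```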
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (7)
44-45: LGTM! Forward declarations for SM types. The forward declarations of `_1SM` and `_2SM` are correct and used by the SMTypeAdapter specializations.
47-52: LGTM! Launcher declaration is well-formed. The `genericBf16GemmKernelLauncherSm100` function signature correctly declares all necessary template parameters and runtime arguments for the SM100 BF16 GEMM launcher.
54-91: LGTM! Cluster shape dispatch covers all supported configurations. The `dispatchGemmClusterShapeSm100` function correctly routes to the appropriate launcher based on the cluster shape (1x1x1, 1x2x1, 1x4x1, 2x1x1, 2x2x1), with proper SM type selection (_1SM or _2SM) and error handling for unsupported shapes.
93-125: LGTM! Tile configuration dispatch handles all SM100 tile sizes. The `dispatchToArch` function correctly:
- Routes to the appropriate tile configuration based on `tile_config_sm100`
- Swaps A/B matrices and m/n dimensions to match column-major layout requirements
- Throws for unsupported tile configurations
127-134: LGTM! GEMM entry point correctly delegates to arch dispatch. The `CutlassBf16GemmRunner<T>::gemm` implementation properly forwards all parameters to the architecture-specific dispatch function.
136-151: LGTM! Workspace size probing handles SMEM constraint failures. The `getWorkspaceSizeImpl` function correctly:
- Probes all available GEMM configurations
- Catches and silently ignores `std::runtime_error` exceptions when configurations exceed SMEM limits
- Returns the maximum workspace size across all valid configurations
The comment "Swallow errors when SMEM exceeds maximum allowed" clearly documents this intentional behavior. Based on learnings from the FP8 implementation, this pattern is acceptable for configuration discovery.
178-202: LGTM! Config generation covers all tile and cluster combinations; a sketch of the enumeration appears after the list below. The `getConfigs` function correctly:
- Enumerates all 5 SM100 tile configurations (64x64x128, 64x128x128, 64x256x128, 128x64x128, 128x128x128)
- Combines with all 5 cluster shapes (1x1x1, 1x2x1, 1x4x1, 2x1x1, 2x2x1)
- Creates 25 candidate configurations for autotuning
- Uses AUTO schedule types for flexibility
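Sketched in Python, the enumeration amounts to a simple cross product (the names below are placeholders, not the PR's identifiers):

```python
from itertools import product

TILES = ["64x64x128", "64x128x128", "64x256x128", "128x64x128", "128x128x128"]
CLUSTERS = ["1x1x1", "1x2x1", "1x4x1", "2x1x1", "2x2x1"]

# 5 tile shapes x 5 cluster shapes = 25 candidate tactics for the autotuner.
configs = list(product(TILES, CLUSTERS))
assert len(configs) == 25
```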
flashinfer/gemm/gemm_base.py (3)
276-384: Well-structured BF16 MM API with proper backend routing. The `mm_bf16` function follows established patterns with proper decorator usage (`@backend_requirement`, `@flashinfer_api`), clear documentation, and correct backend selection logic. The integration with the autotuner via `bf16_gemm_sm100` is well-implemented.
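For readers skimming the review, a minimal usage sketch of the new API (the keyword names `out_dtype` and `backend` follow the signatures referenced above, the autotune context-manager usage mirrors the tests, and the tolerances are arbitrary):

```python
import torch
from flashinfer import mm_bf16
from flashinfer.autotuner import autotune

m, n, k = 64, 4096, 4096  # small-batch regime this PR targets
a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)

# Tune CUTLASS tactics once; later calls reuse the cached best tactic.
with autotune():
    out = mm_bf16(a, b, out_dtype=torch.bfloat16, backend="cutlass")

torch.testing.assert_close(out.float(), (a @ b).float(), rtol=1e-2, atol=1e-2)
```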
433-510: Clean BMM BF16 implementation ready for backend expansion. The `bmm_bf16` function is well-structured with proper validation and workspace management. The design anticipates future cuDNN backend support (as noted in the author's comment at line 438), making it extensible.
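And the batched variant, under the same assumptions about argument names:

```python
import torch
from flashinfer import bmm_bf16

batch, m, n, k = 8, 64, 4096, 1024
A = torch.randn(batch, m, k, device="cuda", dtype=torch.bfloat16)
B = torch.randn(batch, k, n, device="cuda", dtype=torch.bfloat16)

C = bmm_bf16(A, B, out_dtype=torch.bfloat16)
torch.testing.assert_close(C.float(), torch.bmm(A, B).float(), rtol=1e-2, atol=1e-2)
```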
762-796: CUTLASS BF16 runner implementation is consistent with FP8 pattern. The `CutlassBf16GemmRunner` follows the same structure as the FP8 runner, with proper tactic enumeration and tensor unpacking. The `b.transpose(-2, -1)` pattern on line 785 is consistent with the FP8 implementation (line 726).
Signed-off-by: raayandhar <raayan.dhar@gmail.com>
@aleozlx could you trigger CI now that the checks have passed? (if you want others to review just ignore)
/bot run
bkryu left a comment
Hi @raayandhar, this will be a great addition to FlashInfer.
To facilitate benchmarking, would it be possible to add, in this or a subsequent PR, benchmarking support in flashinfer_benchmark.py? A reference benchmark routine for bmm_fp8 can be found here.
One strength of the flashinfer_benchmark.py microbenchmark harness is its ability to output a structured csv for results and compare backends' performances with the same inputs. As we add mm_bf16 and bmm_bf16 APIs, adding them to the benchmark would be helpful in performance tracking.
Yes, I was planning on adding support in an additional PR (this week, probably); I spoke to @aleozlx about it already.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @tests/gemm/test_mm_bf16.py:
- Line 25: The code calls get_compute_capability(torch.device(device="cuda")) but references an undefined variable named device. Update the call to use a string literal instead (e.g., torch.device("cuda") or simply "cuda") so get_compute_capability receives a valid device object/string: locate the call to get_compute_capability and change the torch.device invocation to remove the undefined named argument. A sketch of the fix follows below.
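A sketch of the corrected guard (the surrounding skip logic is an assumption about the test's structure):

```python
import pytest
import torch
from flashinfer.utils import get_compute_capability

# Pass a concrete device object (or the string "cuda") rather than an undefined name.
compute_capability = get_compute_capability(torch.device("cuda"))
if compute_capability[0] != 10:
    pytest.skip(
        "BF16 CUTLASS GEMM requires SM100 (compute capability 10.x)",
        allow_module_level=True,
    )
```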
🧹 Nitpick comments (2)
flashinfer/gemm/gemm_base.py (2)
1650-1657: Consider deduplicating output dtype validation. The `_validate_bf16_output_dtype` function (lines 1650-1657) is identical to `_validate_fp8_output_dtype` (lines 1641-1648). Both validate that output dtype is either `torch.bfloat16` or `torch.float16`.
♻️ Optional refactor to reduce duplication
```diff
-def _validate_bf16_output_dtype(dtype: torch.dtype):
-    """Validate that the output dtype is either bf16 or fp16."""
-    if dtype not in (torch.bfloat16, torch.float16):
-        raise ValueError(
-            f"Unsupported output dtype: {dtype}. "
-            f"Only torch.bfloat16 and torch.float16 are supported for BF16 GEMM operations."
-        )
+def _validate_gemm_output_dtype(dtype: torch.dtype, operation: str = "GEMM"):
+    """Validate that the output dtype is either bf16 or fp16."""
+    if dtype not in (torch.bfloat16, torch.float16):
+        raise ValueError(
+            f"Unsupported output dtype: {dtype}. "
+            f"Only torch.bfloat16 and torch.float16 are supported for {operation} operations."
+        )
+
+def _validate_fp8_output_dtype(dtype: torch.dtype):
+    """Validate that the output dtype is either bf16 or fp16."""
+    _validate_gemm_output_dtype(dtype, "FP8 GEMM")
+
+def _validate_bf16_output_dtype(dtype: torch.dtype):
+    """Validate that the output dtype is either bf16 or fp16."""
+    _validate_gemm_output_dtype(dtype, "BF16 GEMM")
```
3897-3897: Trailing blank line in function. Line 3897 has a blank line that appears to be unintentional within the function body. This is a minor cosmetic issue.
♻️ Clean up trailing blank line
```diff
     if out is None:
         out_dtype = out_dtype or torch.bfloat16
         out = torch.empty(a.shape[0], b.shape[1], dtype=out_dtype, device=a.device)
-
     m_grouped_fp8_gemm_nt_contiguous(
         (a, a_scale), (b, b_scale), out, m_indices, scale_granularity_mnk
     )
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
flashinfer/gemm/gemm_base.pytests/gemm/test_bmm_bf16.pytests/gemm/test_mm_bf16.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/gemm/test_bmm_bf16.py
🧰 Additional context used
📓 Path-based instructions (2)
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
tests/**/*.py: Test implementations should useflashinfer.utilsfunctions (get_compute_capability,is_sm90a_supported,is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing withmpirunon multi-GPU systems, use the pattern:mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes -tests/conftest.pyprovides auto-skipping for OOM tests as a safety net but should not be relied upon
Files:
tests/gemm/test_mm_bf16.py
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use@functools.cachedecorator on Python API functions to implement module-level caching and avoid recompilation
Use@flashinfer_apidecorator for debugging API calls, enable viaFLASHINFER_LOGLEVELenvironment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Files:
flashinfer/gemm/gemm_base.py
🧠 Learnings (13)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
Applied to files:
tests/gemm/test_mm_bf16.pyflashinfer/gemm/gemm_base.py
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Applied to files:
flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Applied to files:
flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers
Applied to files:
flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Applied to files:
flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/**/*.py : Use `flashinfer_api` decorator for debugging API calls, enable via `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Applied to files:
flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API
Applied to files:
flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
Applied to files:
flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Use `FLASHINFER_CUDA_ARCH_LIST` environment variable to specify target GPU architectures (e.g., '8.0 9.0a') and `FLASHINFER_NVCC_THREADS` to control parallel compilation threads
Applied to files:
flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/**/*.py : Use `functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation
Applied to files:
flashinfer/gemm/gemm_base.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Use `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`
Applied to files:
flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (2)
tests/gemm/test_mm_bf16.py (5)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
flashinfer(41-125)include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
flashinfer(26-60)flashinfer/autotuner.py (1)
autotune(251-262)flashinfer/gemm/gemm_base.py (1)
mm_bf16(283-382)flashinfer/utils.py (1)
get_compute_capability(258-261)
flashinfer/gemm/gemm_base.py (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
gemm(42-91)csrc/bf16_gemm_cutlass.cu (2)
bf16_gemm(144-147)bf16_gemm(144-145)
🪛 Ruff (0.14.10)
flashinfer/gemm/gemm_base.py
187-187: Unused function argument: a
(ARG001)
188-188: Unused function argument: b
(ARG001)
189-189: Unused function argument: out
(ARG001)
193-193: Unused function argument: backend
(ARG001)
196-198: Avoid specifying long messages outside the exception class
(TRY003)
200-202: Avoid specifying long messages outside the exception class
(TRY003)
211-211: Unused function argument: a
(ARG001)
212-212: Unused function argument: b
(ARG001)
213-213: Unused function argument: out
(ARG001)
215-215: Unused function argument: bias
(ARG001)
216-216: Unused function argument: pdl
(ARG001)
217-217: Unused function argument: backend
(ARG001)
220-222: Avoid specifying long messages outside the exception class
(TRY003)
230-230: Unused function argument: pdl
(ARG001)
231-231: Unused function argument: out
(ARG001)
232-232: Unused function argument: out_dtype
(ARG001)
233-233: Unused function argument: backend
(ARG001)
236-238: Avoid specifying long messages outside the exception class
(TRY003)
240-242: Avoid specifying long messages outside the exception class
(TRY003)
245-247: Avoid specifying long messages outside the exception class
(TRY003)
254-254: Unused function argument: a
(ARG001)
255-255: Unused function argument: b
(ARG001)
258-258: Unused function argument: out
(ARG001)
259-259: Unused function argument: out_dtype
(ARG001)
260-260: Unused function argument: backend
(ARG001)
353-355: Avoid specifying long messages outside the exception class
(TRY003)
357-359: Avoid specifying long messages outside the exception class
(TRY003)
361-363: Avoid specifying long messages outside the exception class
(TRY003)
387-387: Unused function argument: A
(ARG001)
388-388: Unused function argument: B
(ARG001)
389-389: Unused function argument: out
(ARG001)
391-391: Unused function argument: backend
(ARG001)
401-401: Unused function argument: out
(ARG001)
402-402: Unused function argument: out_dtype
(ARG001)
403-403: Unused function argument: backend
(ARG001)
406-408: Avoid specifying long messages outside the exception class
(TRY003)
410-412: Avoid specifying long messages outside the exception class
(TRY003)
419-419: Unused function argument: A
(ARG001)
420-420: Unused function argument: B
(ARG001)
421-421: Unused function argument: out
(ARG001)
422-422: Unused function argument: out_dtype
(ARG001)
423-423: Unused function argument: backend
(ARG001)
444-444: Unused function argument: backend
(ARG001)
492-494: Avoid specifying long messages outside the exception class
(TRY003)
496-498: Avoid specifying long messages outside the exception class
(TRY003)
500-502: Avoid specifying long messages outside the exception class
(TRY003)
768-768: Unused method argument: inputs
(ARG002)
769-769: Unused method argument: profile
(ARG002)
777-777: Unused method argument: do_preparation
(ARG002)
778-778: Unused method argument: kwargs
(ARG002)
1066-1068: Avoid specifying long messages outside the exception class
(TRY003)
1070-1072: Avoid specifying long messages outside the exception class
(TRY003)
1074-1076: Avoid specifying long messages outside the exception class
(TRY003)
1653-1656: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/gemm/gemm_base.py (1)
370-379: Verify backend parameter handling in heuristic calls. When `backend="cutlass"` or `backend="tgv"`, the code calls `_heuristic_func_mm_bf16` with hardcoded values (`None, False` for cutlass on line 372, and `bias, pdl` for tgv on line 376). However, the function signature and validation should have already rejected invalid combinations (e.g., cutlass with bias). This pattern seems intentional but could be clarified. If a user somehow bypasses validation and calls with `backend="cutlass", bias=<tensor>`, line 372 silently ignores the bias rather than raising an error. Based on learnings and the `@backend_requirement` decorator behavior, the requirement functions should prevent invalid combinations from reaching this code. However, consider whether explicit parameter passing (rather than hardcoding) would be clearer:

```python
    elif backend == "cutlass":
        backends = _heuristic_func_mm_bf16(
            ["cutlass"], a, b, bias, pdl, out, out_dtype, backend
        )
```

This would make the heuristic function responsible for filtering, which it already does.
[FAILED] Pipeline #41300448: 4/20 passed

/bot run

[CANCELING] Pipeline #41366487: canceled
bkryu left a comment
LGTM. Unit tests (shown as cancelled) passed on all key SKUs.
```python
    if out.dtype != out_dtype:
        raise ValueError(
            f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}."
        )
```
[nit] can you move these checks to _check_mm_bf16_problem_size?
Yeah I'll do that in my cuDNN PR.
@raayandhar one of the cutlass maintainers mentioned a transpose trick you can do when M is small (NVIDIA/cutlass#2923 (comment)), with stream_k and split_k; wondering if you tried that? I'm currently taking a look at the block-wise fp8 kernel right now (NVIDIA/cutlass#2923), so let me know if you have any context that would be useful for me here. The cutlass backend performs pretty badly for smaller batch sizes / shapes.
Unfortunately no, I did not try that trick, but it's something I could try in the future; I'll be following that issue thread. If you look at the benchmark script from earlier in this PR, I also observed that at lower batch sizes we were not performing very well, but unfortunately I don't have much more context than that. Maybe @aleozlx or @bkryu might know more?
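For future reference, the trick being discussed is roughly this: when M is very small, solve the transposed problem so the tiny dimension lands on N, which stream-K/split-K schedules tend to handle better. A hedged, framework-level illustration (not something this PR implements):

```python
import torch

def gemm_small_m_via_transpose(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # C = A @ B == (B^T @ A^T)^T; with few rows in A, the transposed problem
    # exposes a larger leading dimension for the kernel to tile over.
    return (b.transpose(-2, -1) @ a.transpose(-2, -1)).transpose(-2, -1)
```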
📌 Description
This issue was opened a little while ago (#1974) and I finally got a chance to tackle it. It is a feature request for BF16 GEMM, and I decided to implement it using the CUTLASS backend. The issue poster was using a B200, so I implemented for B200 (SM100) as well.
🔍 Related Issues
Feature request: #1974
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks before submitting the PR with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Documentation
Tests
✏️ Tip: You can customize this high-level summary in your review settings.