Skip to content

Commit 2062dec

Browse files
raayandhar and bkryu authored
feat: BF16 GEMM for SM100, including CUTLASS, TGV backends (#2070)
<!-- .github/pull_request_template.md --> ## 📌 Description This issue was opened a little while ago (#1974) and I finally got a chance to tackle it. Feature request for BF16 GEMM. I decided to try and implement using CUTLASS backend. The issue poster was using B200 so I implemented for B200 (SM100) as well. <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues Feature request: #1974 <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added high-performance BF16 matrix-multiply APIs (mm_bf16, bmm_bf16) with selectable backend, workspace management, autotuning support, and a runtime tactic query. * Integrated Cutlass-based BF16 GEMM runner, JIT generation for SM100 BF16 kernels, and public runtime entry points for native execution. * **Documentation** * Added BF16 GEMM docs and autosummary entries. * **Tests** * Added unit tests validating mm_bf16 and bmm_bf16 on supported GPUs. 
<sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: raayandhar <raayan.dhar@gmail.com> Signed-off-by: Raayan Dhar raayan.dhar@gmail.com <raayan.dhar@gmail.com> Co-authored-by: Brian Ryu <bryu@nvidia.com>
1 parent ed01158 commit 2062dec

13 files changed

Lines changed: 1269 additions & 12 deletions

File tree

csrc/bf16_gemm_cutlass.cu

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
/*
2+
* Copyright (c) 2025, FlashInfer.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <cuda_fp16.h>
18+
19+
#include <cstddef>
20+
#include <cstdint>
21+
#include <functional>
22+
#include <type_traits>
23+
#include <vector>
24+
25+
#include "flashinfer/gemm/bf16_gemm_cutlass.h"
26+
#include "flashinfer/gemm/bf16_gemm_cutlass_template.h"
27+
#include "flashinfer/gemm/cutlass_gemm_configs.h"
28+
#include "tvm_ffi_utils.h"
29+
30+
using flashinfer::gemm::ClusterShape;
31+
using flashinfer::gemm::CutlassBf16GemmRunner;
32+
using flashinfer::gemm::CutlassBf16GemmRunnerInterface;
33+
using flashinfer::gemm::CutlassGemmConfig;
34+
using flashinfer::gemm::CutlassTileConfigSM100;
35+
using flashinfer::gemm::EpilogueScheduleType;
36+
using flashinfer::gemm::MainloopScheduleType;
37+
38+
namespace flashinfer {
39+
namespace gemm {
40+
template class CutlassBf16GemmRunner<__nv_bfloat16>;
41+
template class CutlassBf16GemmRunner<half>;
42+
} // namespace gemm
43+
} // namespace flashinfer
44+
45+
namespace torch_ext {
46+
47+
namespace {
48+
49+
/// Returns the CUTLASS GEMM config selected by `tactic`.
///
/// The (m, n, k) problem shape is currently unused but kept in the signature
/// for future shape-aware heuristics — TODO confirm with maintainers.
/// Throws via TVM_FFI_ICHECK if `tactic` is outside [0, num_configs).
CutlassGemmConfig getBf16GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tactic) {
  auto getCutlassBf16GemmConfigs = []() {
    CutlassBf16GemmRunner<__nv_bfloat16> gemmRunner;
    return gemmRunner.getConfigs();
  };
  // Cached once; the config list does not depend on the problem shape.
  static std::vector<CutlassGemmConfig> globalConfigs = getCutlassBf16GemmConfigs();
  // Fixed: the old message read "between 0 and N", which is ambiguous about
  // inclusivity and never reported the offending value.
  TVM_FFI_ICHECK(tactic >= 0 && tactic < static_cast<int64_t>(globalConfigs.size()))
      << "tactic must be in [0, " << globalConfigs.size() << "), got " << tactic;
  return globalConfigs[tactic];
}
59+
60+
template <typename T>
61+
void runGemm(TensorView out, TensorView mat1, TensorView mat2, int64_t m, int64_t n, int64_t k,
62+
int64_t b, CutlassGemmConfig const& gemmConfig, TensorView workspace_buffer) {
63+
CutlassBf16GemmRunner<T> gemmRunner;
64+
65+
int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k);
66+
int64_t const provided_workspace_size =
67+
workspace_buffer.numel() * get_element_size(workspace_buffer);
68+
69+
auto runKernel = [&](void* workspace) {
70+
gemmRunner.gemm(static_cast<__nv_bfloat16*>(mat1.data_ptr()),
71+
static_cast<__nv_bfloat16*>(mat2.data_ptr()), out.data_ptr(), m, n, k, b,
72+
gemmConfig, static_cast<char*>(workspace), required_workspace_size,
73+
get_stream(mat1.device()));
74+
};
75+
76+
if (provided_workspace_size < required_workspace_size) {
77+
Tensor new_workspace =
78+
alloc_tensor({required_workspace_size}, DLDataType{kDLInt, 8, 1}, mat1.device());
79+
runKernel(new_workspace.data_ptr());
80+
} else {
81+
runKernel(workspace_buffer.data_ptr());
82+
}
83+
}
84+
85+
/// Validates the bf16 inputs, derives the (b, m, n, k) problem shape from the
/// 2-D (mm) or 3-D (bmm) operands, checks that `out` matches, and dispatches
/// to the templated runner for the requested output dtype.
///
/// `tactic == -1` means "use the default", which maps to config index 0.
void bf16_bmm_impl(TensorView mat1, TensorView mat2, TensorView out, TensorView workspace_buffer,
                   int64_t tactic) {
  CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16);
  CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16);

  int64_t m, n, k, b;
  int const rank = mat1.ndim();
  if (rank == 2) {
    // mat1 is (m, k); mat2 is (n, k). The shared inner dimension is checked,
    // so this computes A * B^T in row-major terms.
    TVM_FFI_ICHECK_EQ(mat2.ndim(), 2) << "mat2 must be a matrix";
    TVM_FFI_ICHECK_EQ(mat1.size(1), mat2.size(1))
        << "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(0) << "x" << mat1.size(1)
        << " and " << mat2.size(0) << "x" << mat2.size(1) << ")";
    m = mat1.size(0);
    n = mat2.size(0);
    k = mat2.size(1);
    b = 1;
  } else if (rank == 3) {
    // Batched case: leading dimension of both operands is the batch size.
    TVM_FFI_ICHECK_EQ(mat2.ndim(), 3) << "mat2 must be a batch of matrices";
    TVM_FFI_ICHECK_EQ(mat1.size(0), mat2.size(0)) << "mat1 and mat2 must have the same batch size ("
                                                  << mat1.size(0) << " and " << mat2.size(0) << ")";
    TVM_FFI_ICHECK_EQ(mat1.size(2), mat2.size(2))
        << "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(1) << "x" << mat1.size(2)
        << " and " << mat2.size(1) << "x" << mat2.size(2) << ")";
    b = mat1.size(0);
    m = mat1.size(1);
    n = mat2.size(1);
    k = mat2.size(2);
  } else {
    TVM_FFI_LOG_AND_THROW(NotImplementedError) << "mat1 must be a matrix or a batch of matrices";
  }

  // -1 selects the default tactic (config index 0).
  auto const gemm_config = getBf16GemmConfig(m, n, k, tactic == -1 ? 0 : tactic);

  std::vector<int64_t> expected_shape =
      rank == 2 ? std::vector<int64_t>{m, n} : std::vector<int64_t>{b, m, n};
  TVM_FFI_ICHECK_EQ(out.ndim(), static_cast<int>(expected_shape.size()))
      << "out must have " << expected_shape.size() << " dimensions, but got " << out.ndim();
  for (int dim = 0; dim < static_cast<int>(expected_shape.size()); ++dim) {
    TVM_FFI_ICHECK_EQ(out.size(dim), expected_shape[dim])
        << "out shape mismatch at dimension " << dim << ": expected " << expected_shape[dim]
        << ", got " << out.size(dim);
  }

  // The output dtype selects which explicit runner instantiation is used.
  switch (encode_dlpack_dtype(out.dtype())) {
    case bfloat16_code:
      runGemm<__nv_bfloat16>(out, mat1, mat2, m, n, k, b, gemm_config, workspace_buffer);
      break;
    case float16_code:
      runGemm<half>(out, mat1, mat2, m, n, k, b, gemm_config, workspace_buffer);
      break;
    default:
      TVM_FFI_LOG_AND_THROW(NotImplementedError) << "out_dtype must be one of fp16/bf16.";
  }
}
141+
142+
} // namespace
143+
144+
/// Public FFI entry point. Thin wrapper: the shared implementation handles
/// both the 2-D (mm) and 3-D (bmm) cases, so it is forwarded unchanged.
void bf16_gemm(TensorView mat1, TensorView mat2, TensorView out, TensorView workspace_buffer,
               int64_t tactic) {
  bf16_bmm_impl(mat1, mat2, out, workspace_buffer, tactic);
}
148+
149+
int64_t bf16_gemm_tactic_num() {
150+
auto getCutlassConfigs = []() {
151+
CutlassBf16GemmRunner<__nv_bfloat16> gemmRunner;
152+
return gemmRunner.getConfigs();
153+
};
154+
static int64_t totalTactics = getCutlassConfigs().size();
155+
return totalTactics;
156+
}
157+
158+
} // namespace torch_ext
159+
160+
TVM_FFI_DLL_EXPORT_TYPED_FUNC(bf16_gemm, torch_ext::bf16_gemm);
161+
TVM_FFI_DLL_EXPORT_TYPED_FUNC(bf16_gemm_tactic_num, torch_ext::bf16_gemm_tactic_num);

csrc/bf16_gemm_cutlass.jinja

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright (c) 2025, FlashInfer.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "flashinfer/gemm/bf16_gemm_template_sm100.h"
18+
19+
namespace flashinfer {
20+
namespace gemm {
21+
INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 1, 1, _1SM);
22+
INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 2, 1, _1SM);
23+
INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 4, 1, _1SM);
24+
INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 1, 1, _2SM);
25+
INSTANCE_BF16_GEMM_TEMPLATE_SM100({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 2, 1, _2SM);
26+
} // namespace gemm
27+
} // namespace flashinfer

docs/api/gemm.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@ flashinfer.gemm
77

88
This module provides a set of GEMM operations.
99

10+
BF16 GEMM
11+
---------
12+
13+
.. autosummary::
14+
:toctree: ../generated
15+
16+
mm_bf16
17+
bmm_bf16
18+
1019
FP4 GEMM
1120
--------
1221

flashinfer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,10 @@
8686
)
8787
from .gdn_prefill import chunk_gated_delta_rule as chunk_gated_delta_rule
8888
from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
89+
from .gemm import bmm_bf16 as bmm_bf16
8990
from .gemm import bmm_fp8 as bmm_fp8
9091
from .gemm import bmm_mxfp8 as bmm_mxfp8
92+
from .gemm import mm_bf16 as mm_bf16
9193
from .gemm import mm_fp4 as mm_fp4
9294
from .gemm import mm_fp8 as mm_fp8
9395
from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100

flashinfer/gemm/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from .gemm_base import SegmentGEMMWrapper as SegmentGEMMWrapper
2+
from .gemm_base import bmm_bf16 as bmm_bf16
23
from .gemm_base import bmm_fp8 as bmm_fp8
34
from .gemm_base import bmm_mxfp8 as bmm_mxfp8
5+
from .gemm_base import mm_bf16 as mm_bf16
46
from .gemm_base import mm_fp4 as mm_fp4
57
from .gemm_base import mm_fp8 as mm_fp8
68
from .gemm_base import tgv_gemm_sm100 as tgv_gemm_sm100
@@ -22,8 +24,10 @@
2224

2325
__all__ = [
2426
"SegmentGEMMWrapper",
27+
"bmm_bf16",
2528
"bmm_fp8",
2629
"bmm_mxfp8",
30+
"mm_bf16",
2731
"mm_fp4",
2832
"mm_fp8",
2933
"tgv_gemm_sm100",

0 commit comments

Comments
 (0)