Changes from 15 commits (25 commits in total):
044dc7e  add cublaslt to mm_bf16 (vadiklyutiy, Mar 28, 2026)
f6073b4  fixes (vadiklyutiy, Mar 28, 2026)
40d0ac3  perf fix (vadiklyutiy, Mar 29, 2026)
17233aa  fix (vadiklyutiy, Mar 29, 2026)
bdeaa6e  Enable multi-tactic autotuning for CublasFp8, CudnnFp8, and CudnnMxfp… (vadiklyutiy, Mar 29, 2026)
3c4d422  add backed to tests (vadiklyutiy, Mar 29, 2026)
214cb05  fix (vadiklyutiy, Mar 29, 2026)
90d9511  CR fixes (vadiklyutiy, Mar 29, 2026)
e633cc7  support fp16,fp32 output (vadiklyutiy, Mar 29, 2026)
98dde99  CR fixes (vadiklyutiy, Mar 29, 2026)
4616733  CR fixes (vadiklyutiy, Mar 29, 2026)
37e2abf  CR fixes (vadiklyutiy, Mar 30, 2026)
cac20a4  coderabbit review fixes (vadiklyutiy, Mar 30, 2026)
9e82e66  coderabbit review fixes (vadiklyutiy, Mar 30, 2026)
b5f957d  Add get_cache_key_extras and fix out docstring (vadiklyutiy, Mar 30, 2026)
989891b  Merge branch 'main' into mm-bf16-cublaslt (vadiklyutiy, Apr 2, 2026)
79dbf10  Add rationale comment for __hash__ robustification (vadiklyutiy, Apr 3, 2026)
ae8f830  Raise error for invalid tactic index (vadiklyutiy, Apr 3, 2026)
b63f460  Replace build_all with policy for mxfp8 graph (vadiklyutiy, Apr 3, 2026)
5e8367b  Unify algo cache key with get_cache_key_extras (vadiklyutiy, Apr 3, 2026)
ef9941a  Use FP32 compute for FP16 cuBLASLt output (vadiklyutiy, Apr 7, 2026)
ccbca2b  Merge branch 'main' into mm-bf16-cublaslt (vadiklyutiy, Apr 8, 2026)
40e5d61  Merge branch 'main' into mm-bf16-cublaslt (vadiklyutiy, Apr 22, 2026)
0ccf6c1  fix misprint (vadiklyutiy, Apr 22, 2026)
c82cb9f  Merge branch 'main' into mm-bf16-cublaslt (dhiraj113, Apr 22, 2026)
15 changes: 12 additions & 3 deletions benchmarks/routines/gemm.py
@@ -140,7 +140,16 @@ def parse_gemm_args(line, parser):
required=False,
nargs="+",
default=["cudnn"],
choices=["cudnn", "cublas", "trtllm", "cutlass", "tgv", "cute-dsl", "auto"],
choices=[
"cudnn",
"cublas",
"trtllm",
"cutlass",
"tgv",
"cublaslt",
"cute-dsl",
"auto",
],
help="Kernel backends to test. Default: cudnn",
)
parser.add_argument(
@@ -1553,7 +1562,7 @@ def testMmBf16(args):
use_pdl = getattr(args, "enable_pdl", False)
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
autotune_supported_backends = ["cudnn", "cutlass", "tgv", "auto"]
autotune_supported_backends = ["cudnn", "cutlass", "tgv", "cublaslt", "auto"]
res = []

out_dtype = dtype_str_to_torch_dtype(args.out_dtype)
@@ -1618,7 +1627,7 @@ def testMmBf16(args):
return res

def run_backend(backend, a, b, bias, use_pdl, out_dtype):
if backend in ["cudnn", "cutlass", "tgv", "auto"]:
if backend in ["cudnn", "cutlass", "tgv", "cublaslt", "auto"]:
return flashinfer.mm_bf16(
a=a,
b=b,
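The benchmark change above only adds "cublaslt" to the accepted backend names and routes it through `flashinfer.mm_bf16`. A minimal sketch of the call the benchmark now exercises; the `backend` and `out_dtype` keyword names and the (m, k) x (k, n) operand layout are assumptions here, since the full `mm_bf16` signature is not visible in this hunk:

```python
# Minimal sketch of the path exercised for backend="cublaslt" (keyword names and
# operand layout beyond a/b are assumptions, not taken from this diff).
import torch
import flashinfer

a = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)  # (m, k)
b = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)  # (k, n), layout assumed

out = flashinfer.mm_bf16(a=a, b=b, backend="cublaslt", out_dtype=torch.bfloat16)
```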
91 changes: 90 additions & 1 deletion csrc/bmm_fp8.cu
@@ -49,7 +49,8 @@ void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, Tenso
auto stream = get_stream(A.device());

auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
workspace_buffer.data_ptr(), workspace_buffer.numel(),
workspace_buffer.data_ptr(),
workspace_buffer.numel() * get_element_size(workspace_buffer),
static_cast<b_type*>(B.data_ptr()), static_cast<a_type*>(A.data_ptr()),
static_cast<d_type*>(D.data_ptr()), batch_size, n, m, k,
static_cast<float*>(B_scale.data_ptr()), static_cast<float*>(A_scale.data_ptr()),
@@ -61,3 +62,91 @@
});
});
}

int64_t bmm_fp8_get_algos(TensorView A, TensorView B, TensorView D, TensorView A_scale,
TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle,
TensorView algo_buffer) {
CHECK_CUDA(A);
CHECK_CUDA(B);
CHECK_CUDA(D);
CHECK_DIM(3, A);
CHECK_DIM(3, B);
CHECK_DIM(3, D);
CHECK_CONTIGUOUS(algo_buffer);
TVM_FFI_ICHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0)) << "Batch sizes must match";
TVM_FFI_ICHECK(A.size(2) == B.size(1)) << "Incompatible matrix sizes";
TVM_FFI_ICHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2))
<< "Result tensor has incorrect shape";

int64_t result = 0;
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(B.dtype(), b_type, [&] {
return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(A.dtype(), a_type, [&] {
return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(D.dtype(), d_type, [&] {
auto batch_size = A.size(0);
auto m = A.size(1);
auto k = A.size(2);
auto n = B.size(2);

auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
ffi::CUDADeviceGuard device_guard(A.device().device_id);

int max_algos = static_cast<int>(algo_buffer.numel() * get_element_size(algo_buffer) /
flashinfer::bmm_fp8::kAlgoBytes);
result = flashinfer::bmm_fp8::get_fp8_algorithms<b_type, a_type, d_type>(
batch_size, n, m, k, static_cast<float*>(B_scale.data_ptr()),
static_cast<float*>(A_scale.data_ptr()),
workspace_buffer.numel() * get_element_size(workspace_buffer), lt_handle,
algo_buffer.data_ptr(), max_algos);
return true;
});
});
});
return static_cast<int64_t>(result);
}

void bmm_fp8_run_with_algo(TensorView A, TensorView B, TensorView D, TensorView A_scale,
TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle,
TensorView algo_buffer, int64_t algo_idx) {
CHECK_CUDA(A);
CHECK_CUDA(B);
CHECK_CUDA(D);
CHECK_DIM(3, A);
CHECK_DIM(3, B);
CHECK_DIM(3, D);
CHECK_CONTIGUOUS(algo_buffer);
TVM_FFI_ICHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0)) << "Batch sizes must match";
TVM_FFI_ICHECK(A.size(2) == B.size(1)) << "Incompatible matrix sizes";
TVM_FFI_ICHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2))
<< "Result tensor has incorrect shape";

int64_t max_algos =
algo_buffer.numel() * get_element_size(algo_buffer) / flashinfer::bmm_fp8::kAlgoBytes;
TVM_FFI_ICHECK(algo_idx >= 0 && algo_idx < max_algos)
<< "algo_idx " << algo_idx << " out of range [0, " << max_algos << ")";

DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(B.dtype(), b_type, [&] {
return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(A.dtype(), a_type, [&] {
return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(D.dtype(), d_type, [&] {
auto batch_size = A.size(0);
auto m = A.size(1);
auto k = A.size(2);
auto n = B.size(2);

auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
ffi::CUDADeviceGuard device_guard(A.device().device_id);
auto stream = get_stream(A.device());

auto status = flashinfer::bmm_fp8::bmm_fp8_run_with_algo<b_type, a_type, d_type>(
workspace_buffer.data_ptr(),
workspace_buffer.numel() * get_element_size(workspace_buffer),
static_cast<b_type*>(B.data_ptr()), static_cast<a_type*>(A.data_ptr()),
static_cast<d_type*>(D.data_ptr()), batch_size, n, m, k,
static_cast<float*>(B_scale.data_ptr()), static_cast<float*>(A_scale.data_ptr()),
lt_handle, stream, algo_buffer.data_ptr(), static_cast<int>(algo_idx));
TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS)
<< "bmm_fp8_run_with_algo failed: " << cublasGetStatusString(status);
return true;
});
});
});
}
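Taken together, `bmm_fp8_get_algos` and `bmm_fp8_run_with_algo` enable a query-once, replay-by-index tuning flow: serialize the cuBLASLt heuristic candidates into a byte buffer, then time each candidate by index and keep the fastest. A rough sketch of that flow, assuming the operands, handle, and a compiled module handle (`module`) are prepared as for the existing `bmm_fp8` entry point; `ALGO_BYTES`, `MAX_ALGOS`, and the workspace size are illustrative stand-ins, not values taken from this diff:

```python
# Two-phase tuning sketch for the new bmm_fp8 bindings. A, B, D, A_scale, B_scale,
# cublas_handle and `module` are assumed to exist; ALGO_BYTES stands in for
# flashinfer::bmm_fp8::kAlgoBytes.
import torch

ALGO_BYTES = 64
MAX_ALGOS = 8
algo_buffer = torch.empty(MAX_ALGOS * ALGO_BYTES, dtype=torch.uint8)         # contiguous algo storage
workspace = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # byte-sized workspace

# Phase 1: serialize up to MAX_ALGOS heuristic candidates into algo_buffer.
num_algos = module.bmm_fp8_get_algos(
    A, B, D, A_scale, B_scale, workspace, cublas_handle, algo_buffer)

def time_ms(fn, iters=10):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Phase 2: replay each candidate by index and keep the fastest one.
# Out-of-range indices now raise on the C++ side rather than silently clamping.
best_idx = min(
    range(num_algos),
    key=lambda i: time_ms(lambda: module.bmm_fp8_run_with_algo(
        A, B, D, A_scale, B_scale, workspace, cublas_handle, algo_buffer, i)),
)
```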
10 changes: 10 additions & 0 deletions csrc/flashinfer_gemm_binding.cu
@@ -19,9 +19,19 @@
void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, TensorView B_scale,
TensorView workspace_buffer, int64_t cublas_handle);

int64_t bmm_fp8_get_algos(TensorView A, TensorView B, TensorView D, TensorView A_scale,
TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle,
TensorView algo_buffer);

void bmm_fp8_run_with_algo(TensorView A, TensorView B, TensorView D, TensorView A_scale,
TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle,
TensorView algo_buffer, int64_t algo_idx);

void CutlassSegmentGEMM(TensorView workspace_buffer, TensorView all_problems, TensorView x_ptr,
TensorView w_ptr, TensorView y_ptr, TensorView x_ld, TensorView w_ld,
TensorView y_ld, TensorView empty_x_data, bool weight_column_major);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(cutlass_segment_gemm, CutlassSegmentGEMM);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8, bmm_fp8);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8_get_algos, bmm_fp8_get_algos);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8_run_with_algo, bmm_fp8_run_with_algo);
126 changes: 126 additions & 0 deletions csrc/mm_bf16_cublaslt.cu
@@ -0,0 +1,126 @@
/*
* Copyright (c) 2026 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuda_bf16.h>
#include <driver_types.h>

#include <flashinfer/gemm/mm_bf16_cublaslt.cuh>

#include "tvm_ffi_utils.h"

namespace {

cudaDataType_t get_d_type(DLDataType dtype) {
switch (encode_dlpack_dtype(dtype)) {
case bfloat16_code:
return CUDA_R_16BF;
case float16_code:
return CUDA_R_16F;
case float32_code:
return CUDA_R_32F;
default:
TVM_FFI_LOG_AND_THROW(NotImplementedError) << "out_dtype must be one of bf16/fp16/fp32.";
return CUDA_R_16BF;
}
}

} // namespace

// Serialize all heuristic algorithms into a CPU uint8 tensor for caching.
// algo_buffer: CPU uint8 tensor of size >= kMaxAlgorithms * kAlgoBytes.
// Returns number of algorithms written.
int64_t mm_bf16_cublaslt_get_algos(TensorView mat1, TensorView mat2, TensorView out,
TensorView workspace_buffer, int64_t cublas_handle,
TensorView algo_buffer) {
CHECK_CUDA(mat1);
CHECK_CUDA(mat2);
CHECK_CUDA(out);
CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16);
CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
CHECK_DIM(2, out);
CHECK_CPU(algo_buffer);
CHECK_CONTIGUOUS(algo_buffer);
CHECK_CUDA(workspace_buffer);

int64_t m = mat1.size(0);
int64_t k = mat1.size(1);
int64_t n = mat2.size(0);

TVM_FFI_ICHECK_EQ(mat2.size(1), k)
<< "mat2 K dimension mismatch: expected " << k << ", got " << mat2.size(1);
TVM_FFI_ICHECK_EQ(out.size(0), m) << "out M dimension mismatch";
TVM_FFI_ICHECK_EQ(out.size(1), n) << "out N dimension mismatch";

cudaDataType_t d_type = get_d_type(out.dtype());

ffi::CUDADeviceGuard device_guard(mat1.device().device_id);
auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
int max_algos = static_cast<int>(algo_buffer.numel() * get_element_size(algo_buffer) /
flashinfer::mm_bf16_cublaslt::kAlgoBytes);
return static_cast<int64_t>(flashinfer::mm_bf16_cublaslt::get_algorithms(
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), d_type,
workspace_buffer.numel() * get_element_size(workspace_buffer), lt_handle,
algo_buffer.data_ptr(), max_algos));
}

// Run matmul using a pre-cached algorithm — zero heuristic overhead.
void mm_bf16_cublaslt_run_with_algo(TensorView mat1, TensorView mat2, TensorView out,
TensorView workspace_buffer, int64_t cublas_handle,
TensorView algo_buffer, int64_t algo_idx) {
CHECK_CUDA(mat1);
CHECK_CUDA(mat2);
CHECK_CUDA(out);
CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16);
CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
CHECK_DIM(2, out);
CHECK_CPU(algo_buffer);
CHECK_CONTIGUOUS(algo_buffer);
CHECK_CUDA(workspace_buffer);

int64_t m = mat1.size(0);
int64_t k = mat1.size(1);
int64_t n = mat2.size(0);

TVM_FFI_ICHECK_EQ(mat2.size(1), k)
<< "mat2 K dimension mismatch: expected " << k << ", got " << mat2.size(1);
TVM_FFI_ICHECK_EQ(out.size(0), m) << "out M dimension mismatch";
TVM_FFI_ICHECK_EQ(out.size(1), n) << "out N dimension mismatch";

int64_t max_algos = algo_buffer.numel() * get_element_size(algo_buffer) /
flashinfer::mm_bf16_cublaslt::kAlgoBytes;
TVM_FFI_ICHECK(algo_idx >= 0 && algo_idx < max_algos)
<< "algo_idx " << algo_idx << " out of range [0, " << max_algos << ")";

auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
ffi::CUDADeviceGuard device_guard(mat1.device().device_id);
auto stream = get_stream(mat1.device());
cudaDataType_t d_type = get_d_type(out.dtype());

auto status = flashinfer::mm_bf16_cublaslt::run_with_algo(
static_cast<__nv_bfloat16*>(mat1.data_ptr()), static_cast<__nv_bfloat16*>(mat2.data_ptr()),
out.data_ptr(), static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), d_type,
workspace_buffer.data_ptr(), workspace_buffer.numel() * get_element_size(workspace_buffer),
lt_handle, stream, algo_buffer.data_ptr(), static_cast<int>(algo_idx));
TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS)
<< "mm_bf16_cublaslt_run_with_algo failed: " << cublasGetStatusString(status);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_get_algos, mm_bf16_cublaslt_get_algos);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_run_with_algo, mm_bf16_cublaslt_run_with_algo);
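The bf16 path follows the same pattern: serialize algorithms once per problem shape and output dtype, then replay the chosen index with no heuristic query on the hot path. A sketch of a per-shape cache built on top of these two exports; the cache key, `module` handle, `pick_fastest` helper, and buffer sizes are assumptions for illustration, only the two exported functions come from this file:

```python
# Per-shape tactic cache sketch around mm_bf16_cublaslt_get_algos / _run_with_algo.
# `module`, `lt_handle`, `pick_fastest`, ALGO_BYTES and MAX_ALGOS are illustrative.
import torch

_algo_cache = {}   # (m, n, k, out_dtype) -> (algo_buffer, best_idx)

def mm_bf16_cublaslt(mat1, mat2, out, workspace, lt_handle):
    m, k = mat1.shape
    n = mat2.shape[0]                  # mat2 is (n, k) at this level, per the checks above
    key = (m, n, k, out.dtype)
    if key not in _algo_cache:
        algo_buffer = torch.empty(MAX_ALGOS * ALGO_BYTES, dtype=torch.uint8)  # CPU, contiguous
        count = module.mm_bf16_cublaslt_get_algos(
            mat1, mat2, out, workspace, lt_handle, algo_buffer)
        # pick_fastest would time each index, e.g. with a loop like the fp8 sketch earlier.
        best_idx = pick_fastest(count, mat1, mat2, out, workspace, lt_handle, algo_buffer)
        _algo_cache[key] = (algo_buffer, best_idx)
    algo_buffer, best_idx = _algo_cache[key]
    # Replays the serialized cuBLASLt algorithm; no heuristic query on this path.
    module.mm_bf16_cublaslt_run_with_algo(
        mat1, mat2, out, workspace, lt_handle, algo_buffer, best_idx)
    return out
```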
3 changes: 3 additions & 0 deletions flashinfer/aot.py
@@ -77,6 +77,7 @@
gen_gemm_sm100_module_cutlass_mxfp8,
gen_gemm_sm120_module,
gen_gemm_sm120_module_cutlass_fp4,
gen_mm_bf16_cublaslt_module,
gen_tgv_gemm_sm10x_module,
gen_trtllm_gen_gemm_module,
gen_trtllm_low_latency_gemm_module,
@@ -511,6 +512,8 @@ def gen_all_modules(
)
jit_specs.append(gen_tgv_gemm_sm10x_module(torch.float16, use_sm_100f=True))
jit_specs.append(gen_moe_utils_module())
if has_sm100 or has_sm103:
jit_specs.append(gen_mm_bf16_cublaslt_module())
if has_sm103:
jit_specs.append(gen_fp4_quantization_sm103_module())
jit_specs.append(gen_cutlass_fused_moe_sm103_module())
11 changes: 10 additions & 1 deletion flashinfer/autotuner.py
@@ -397,7 +397,16 @@ def forward(
raise NotImplementedError

def __hash__(self):
return hash(tuple(self.__dict__.values()))
hashable_vals = []
for k, v in self.__dict__.items():
if k.endswith("_cache"):
continue
try:
hash(v)
hashable_vals.append(v)
except TypeError:
hashable_vals.append(id(v))
return hash(tuple(hashable_vals))


@contextlib.contextmanager
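A short illustration of what the `__hash__` change above guards against; the `FakeRunner` class is a stand-in for whatever autotuner object carries these attributes, not code from this PR:

```python
# Why hash(tuple(self.__dict__.values())) was fragile: any list/dict attribute makes
# it raise TypeError, and per-instance caches would make equal objects hash differently.
class FakeRunner:
    def __init__(self):
        self.name = "mm_bf16_cublaslt"
        self.tactics = [0, 1, 2]     # unhashable list attribute
        self.algo_cache = {}         # mutable per-instance cache, excluded from the hash

    def __hash__(self):
        hashable_vals = []
        for k, v in self.__dict__.items():
            if k.endswith("_cache"):           # skip cache attributes entirely
                continue
            try:
                hash(v)
                hashable_vals.append(v)
            except TypeError:
                hashable_vals.append(id(v))    # identity fallback for unhashable values
        return hash(tuple(hashable_vals))

r = FakeRunner()
print(hash(r))                                  # works
# hash(tuple(r.__dict__.values()))  # would raise TypeError because of the list/dict
```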