-
Notifications
You must be signed in to change notification settings - Fork 931
Expand file tree
/
Copy pathmm_bf16_cublaslt.cu
More file actions
126 lines (108 loc) · 4.93 KB
/
mm_bf16_cublaslt.cu
File metadata and controls
126 lines (108 loc) · 4.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
/*
* Copyright (c) 2026 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_bf16.h>
#include <driver_types.h>
#include <flashinfer/gemm/mm_bf16_cublaslt.cuh>
#include "tvm_ffi_utils.h"
namespace {
// Translate a DLPack dtype into the matching CUDA data-type enum for cublasLt.
// Supported output dtypes are bf16, fp16, and fp32; anything else throws
// NotImplementedError via the TVM FFI error machinery.
cudaDataType_t get_d_type(DLDataType dtype) {
  const auto dtype_code = encode_dlpack_dtype(dtype);
  if (dtype_code == bfloat16_code) {
    return CUDA_R_16BF;
  }
  if (dtype_code == float16_code) {
    return CUDA_R_16F;
  }
  if (dtype_code == float32_code) {
    return CUDA_R_32F;
  }
  TVM_FFI_LOG_AND_THROW(NotImplementedError) << "out_dtype must be one of bf16/fp16/fp32.";
  return CUDA_R_16BF;  // unreachable: the throw above exits; keeps the compiler quiet
}
}  // namespace
// Serialize all heuristic algorithms into a CPU uint8 tensor for caching.
// mat1: (m, k) contiguous bf16 CUDA tensor.
// mat2: (n, k) contiguous bf16 CUDA tensor (K dims must match).
// out:  (m, n) CUDA tensor; its dtype selects the output CUDA data type
//       (bf16/fp16/fp32, enforced by get_d_type).
// workspace_buffer: CUDA scratch tensor; only its byte size is used here.
// cublas_handle: an existing cublasLtHandle_t passed through as an integer.
// algo_buffer: CPU uint8 tensor of size >= kMaxAlgorithms * kAlgoBytes.
// Returns number of algorithms written.
int64_t mm_bf16_cublaslt_get_algos(TensorView mat1, TensorView mat2, TensorView out,
                                   TensorView workspace_buffer, int64_t cublas_handle,
                                   TensorView algo_buffer) {
  CHECK_CUDA(mat1);
  CHECK_CUDA(mat2);
  CHECK_CUDA(out);
  CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16);
  CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16);
  CHECK_DIM(2, mat1);
  CHECK_DIM(2, mat2);
  CHECK_DIM(2, out);
  CHECK_CPU(algo_buffer);
  CHECK_CONTIGUOUS(algo_buffer);
  CHECK_CUDA(workspace_buffer);
  int64_t m = mat1.size(0);
  int64_t k = mat1.size(1);
  int64_t n = mat2.size(0);
  TVM_FFI_ICHECK_EQ(mat2.size(1), k)
      << "mat2 K dimension mismatch: expected " << k << ", got " << mat2.size(1);
  TVM_FFI_ICHECK_EQ(out.size(0), m) << "out M dimension mismatch";
  TVM_FFI_ICHECK_EQ(out.size(1), n) << "out N dimension mismatch";
  cudaDataType_t d_type = get_d_type(out.dtype());
  ffi::CUDADeviceGuard device_guard(mat1.device().device_id);
  auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
  // Capacity of the cache in algorithms. Compute in int64_t — consistent with
  // mm_bf16_cublaslt_run_with_algo — so the byte count is not narrowed to int
  // before the division; narrow only at the call boundary below.
  int64_t max_algos = algo_buffer.numel() * get_element_size(algo_buffer) /
                      flashinfer::mm_bf16_cublaslt::kAlgoBytes;
  return static_cast<int64_t>(flashinfer::mm_bf16_cublaslt::get_algorithms(
      static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), d_type,
      workspace_buffer.numel() * get_element_size(workspace_buffer), lt_handle,
      algo_buffer.data_ptr(), static_cast<int>(max_algos)));
}
// Run matmul using a pre-cached algorithm — zero heuristic overhead.
// mat1: (m, k) contiguous bf16 CUDA tensor; mat2: (n, k) contiguous bf16 CUDA
// tensor; out: (m, n) CUDA tensor whose dtype picks the output CUDA data type.
// algo_buffer is the CPU cache previously filled by mm_bf16_cublaslt_get_algos,
// and algo_idx selects one serialized algorithm from it.
void mm_bf16_cublaslt_run_with_algo(TensorView mat1, TensorView mat2, TensorView out,
                                    TensorView workspace_buffer, int64_t cublas_handle,
                                    TensorView algo_buffer, int64_t algo_idx) {
  // Validate placement, dtype, and rank of every operand (same order as the
  // get_algos entry point so failure messages line up between the two).
  CHECK_CUDA(mat1);
  CHECK_CUDA(mat2);
  CHECK_CUDA(out);
  CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16);
  CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16);
  CHECK_DIM(2, mat1);
  CHECK_DIM(2, mat2);
  CHECK_DIM(2, out);
  CHECK_CPU(algo_buffer);
  CHECK_CONTIGUOUS(algo_buffer);
  CHECK_CUDA(workspace_buffer);
  const int64_t dim_m = mat1.size(0);
  const int64_t dim_k = mat1.size(1);
  const int64_t dim_n = mat2.size(0);
  TVM_FFI_ICHECK_EQ(mat2.size(1), dim_k)
      << "mat2 K dimension mismatch: expected " << dim_k << ", got " << mat2.size(1);
  TVM_FFI_ICHECK_EQ(out.size(0), dim_m) << "out M dimension mismatch";
  TVM_FFI_ICHECK_EQ(out.size(1), dim_n) << "out N dimension mismatch";
  // How many serialized algorithms fit in the cache; algo_idx must be inside.
  const int64_t max_algos = algo_buffer.numel() * get_element_size(algo_buffer) /
                            flashinfer::mm_bf16_cublaslt::kAlgoBytes;
  TVM_FFI_ICHECK(algo_idx >= 0 && algo_idx < max_algos)
      << "algo_idx " << algo_idx << " out of range [0, " << max_algos << ")";
  auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
  ffi::CUDADeviceGuard device_guard(mat1.device().device_id);
  auto stream = get_stream(mat1.device());
  const cudaDataType_t d_type = get_d_type(out.dtype());
  const auto workspace_bytes = workspace_buffer.numel() * get_element_size(workspace_buffer);
  const auto status = flashinfer::mm_bf16_cublaslt::run_with_algo(
      static_cast<__nv_bfloat16*>(mat1.data_ptr()), static_cast<__nv_bfloat16*>(mat2.data_ptr()),
      out.data_ptr(), static_cast<int>(dim_m), static_cast<int>(dim_n), static_cast<int>(dim_k),
      d_type, workspace_buffer.data_ptr(), workspace_bytes, lt_handle, stream,
      algo_buffer.data_ptr(), static_cast<int>(algo_idx));
  TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS)
      << "mm_bf16_cublaslt_run_with_algo failed: " << cublasGetStatusString(status);
}
// Export both entry points through the TVM FFI under their own names.
TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_get_algos, mm_bf16_cublaslt_get_algos);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_run_with_algo, mm_bf16_cublaslt_run_with_algo);