
Commit 7c014be

Merge branch 'main' into fmhav2-bench
2 parents: 42fc839 + e1e6714

71 files changed

Lines changed: 5199 additions & 602 deletions


benchmarks/routines/attention.py

Lines changed: 1 addition & 0 deletions
@@ -1305,6 +1305,7 @@ def run_backend_wrapper(
                 batch_size=batch_size,
                 cum_seq_lens_q=qo_indptr,
                 cum_seq_lens_kv=kv_indptr,
+                causal=causal,
                 kv_cache_sf=kv_cache_sf,
             )
         elif backend == "cudnn-native":

benchmarks/routines/gemm.py

Lines changed: 3 additions & 2 deletions
@@ -148,6 +148,7 @@ def parse_gemm_args(line, parser):
             "trtllm",
             "cutlass",
             "tgv",
+            "cublaslt",
             "cute-dsl",
             "b12x",
             "auto",
@@ -1585,7 +1586,7 @@ def testMmBf16(args):
     use_pdl = getattr(args, "enable_pdl", False)
     is_cuda_graph_compatible = not args.no_cuda_graph
     run_refcheck = args.refcheck
-    autotune_supported_backends = ["cudnn", "cutlass", "tgv", "auto"]
+    autotune_supported_backends = ["cudnn", "cutlass", "tgv", "cublaslt", "auto"]
     res = []

     out_dtype = dtype_str_to_torch_dtype(args.out_dtype)
@@ -1650,7 +1651,7 @@ def testMmBf16(args):
         return res

     def run_backend(backend, a, b, bias, use_pdl, out_dtype):
-        if backend in ["cudnn", "cutlass", "tgv", "auto"]:
+        if backend in ["cudnn", "cutlass", "tgv", "cublaslt", "auto"]:
             return flashinfer.mm_bf16(
                 a=a,
                 b=b,

csrc/bmm_fp8.cu

Lines changed: 90 additions & 1 deletion
@@ -49,7 +49,8 @@ void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, Tenso
       auto stream = get_stream(A.device());

       auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
-          workspace_buffer.data_ptr(), workspace_buffer.numel(),
+          workspace_buffer.data_ptr(),
+          workspace_buffer.numel() * get_element_size(workspace_buffer),
           static_cast<b_type*>(B.data_ptr()), static_cast<a_type*>(A.data_ptr()),
           static_cast<d_type*>(D.data_ptr()), batch_size, n, m, k,
           static_cast<float*>(B_scale.data_ptr()), static_cast<float*>(A_scale.data_ptr()),
@@ -61,3 +62,91 @@ void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, Tenso
     });
   });
 }
+
+int64_t bmm_fp8_get_algos(TensorView A, TensorView B, TensorView D, TensorView A_scale,
+                          TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle,
+                          TensorView algo_buffer) {
+  CHECK_CUDA(A);
+  CHECK_CUDA(B);
+  CHECK_CUDA(D);
+  CHECK_DIM(3, A);
+  CHECK_DIM(3, B);
+  CHECK_DIM(3, D);
+  CHECK_CONTIGUOUS(algo_buffer);
+  TVM_FFI_ICHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0)) << "Batch sizes must match";
+  TVM_FFI_ICHECK(A.size(2) == B.size(1)) << "Incompatible matrix sizes";
+  TVM_FFI_ICHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2))
+      << "Result tensor has incorrect shape";
+
+  int64_t result = 0;
+  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(B.dtype(), b_type, [&] {
+    return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(A.dtype(), a_type, [&] {
+      return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(D.dtype(), d_type, [&] {
+        auto batch_size = A.size(0);
+        auto m = A.size(1);
+        auto k = A.size(2);
+        auto n = B.size(2);
+
+        auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
+        ffi::CUDADeviceGuard device_guard(A.device().device_id);
+
+        int max_algos = static_cast<int>(algo_buffer.numel() * get_element_size(algo_buffer) /
+                                         flashinfer::bmm_fp8::kAlgoBytes);
+        result = flashinfer::bmm_fp8::get_fp8_algorithms<b_type, a_type, d_type>(
+            batch_size, n, m, k, static_cast<float*>(B_scale.data_ptr()),
+            static_cast<float*>(A_scale.data_ptr()),
+            workspace_buffer.numel() * get_element_size(workspace_buffer), lt_handle,
+            algo_buffer.data_ptr(), max_algos);
+        return true;
+      });
+    });
+  });
+  return static_cast<int64_t>(result);
+}
+
+void bmm_fp8_run_with_algo(TensorView A, TensorView B, TensorView D, TensorView A_scale,
+                           TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle,
+                           TensorView algo_buffer, int64_t algo_idx) {
+  CHECK_CUDA(A);
+  CHECK_CUDA(B);
+  CHECK_CUDA(D);
+  CHECK_DIM(3, A);
+  CHECK_DIM(3, B);
+  CHECK_DIM(3, D);
+  CHECK_CONTIGUOUS(algo_buffer);
+  TVM_FFI_ICHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0)) << "Batch sizes must match";
+  TVM_FFI_ICHECK(A.size(2) == B.size(1)) << "Incompatible matrix sizes";
+  TVM_FFI_ICHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2))
+      << "Result tensor has incorrect shape";
+
+  int64_t max_algos =
+      algo_buffer.numel() * get_element_size(algo_buffer) / flashinfer::bmm_fp8::kAlgoBytes;
+  TVM_FFI_ICHECK(algo_idx >= 0 && algo_idx < max_algos)
+      << "algo_idx " << algo_idx << " out of range [0, " << max_algos << ")";
+
+  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(B.dtype(), b_type, [&] {
+    return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(A.dtype(), a_type, [&] {
+      return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(D.dtype(), d_type, [&] {
+        auto batch_size = A.size(0);
+        auto m = A.size(1);
+        auto k = A.size(2);
+        auto n = B.size(2);
+
+        auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
+        ffi::CUDADeviceGuard device_guard(A.device().device_id);
+        auto stream = get_stream(A.device());
+
+        auto status = flashinfer::bmm_fp8::bmm_fp8_run_with_algo<b_type, a_type, d_type>(
+            workspace_buffer.data_ptr(),
+            workspace_buffer.numel() * get_element_size(workspace_buffer),
+            static_cast<b_type*>(B.data_ptr()), static_cast<a_type*>(A.data_ptr()),
+            static_cast<d_type*>(D.data_ptr()), batch_size, n, m, k,
+            static_cast<float*>(B_scale.data_ptr()), static_cast<float*>(A_scale.data_ptr()),
+            lt_handle, stream, algo_buffer.data_ptr(), static_cast<int>(algo_idx));
+        TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS)
+            << "bmm_fp8_run_with_algo failed: " << cublasGetStatusString(status);
+        return true;
+      });
+    });
+  });
+}
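The two new entry points split cuBLASLt autotuning for bmm_fp8 into an enumerate phase (bmm_fp8_get_algos fills algo_buffer with up to max_algos serialized algorithm descriptors and returns the count) and a replay phase (bmm_fp8_run_with_algo runs the GEMM with the descriptor at algo_idx). A minimal sketch of the tune loop a caller could build on top, with the binding calls and a cudaEvent-based timer abstracted behind hypothetical std::function wrappers; only the two-phase control flow is implied by this commit:

#include <cfloat>
#include <cstdint>
#include <functional>

// Enumerate-then-replay autotune loop. get_algos / run_with_algo stand in for
// bmm_fp8_get_algos and bmm_fp8_run_with_algo with tensors, workspace, and
// cuBLASLt handle already bound; time_ms stands in for a CUDA-event timing
// helper. All three wrappers are hypothetical, not part of this commit.
inline int64_t pick_best_fp8_algo(
    const std::function<int64_t()>& get_algos,
    const std::function<void(int64_t)>& run_with_algo,
    const std::function<float(const std::function<void()>&)>& time_ms) {
  int64_t n_algos = get_algos();  // phase 1: fill algo_buffer, count candidates
  int64_t best = -1;
  float best_ms = FLT_MAX;
  for (int64_t i = 0; i < n_algos; ++i) {  // phase 2: time each candidate
    float ms = time_ms([&] { run_with_algo(i); });
    if (ms < best_ms) {
      best_ms = ms;
      best = i;
    }
  }
  return best;  // reuse this algo_idx on the hot path
}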

csrc/flashinfer_gemm_binding.cu

Lines changed: 10 additions & 0 deletions
@@ -19,9 +19,19 @@
 void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, TensorView B_scale,
              TensorView workspace_buffer, int64_t cublas_handle);

+int64_t bmm_fp8_get_algos(TensorView A, TensorView B, TensorView D, TensorView A_scale,
+                          TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle,
+                          TensorView algo_buffer);
+
+void bmm_fp8_run_with_algo(TensorView A, TensorView B, TensorView D, TensorView A_scale,
+                           TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle,
+                           TensorView algo_buffer, int64_t algo_idx);
+
 void CutlassSegmentGEMM(TensorView workspace_buffer, TensorView all_problems, TensorView x_ptr,
                         TensorView w_ptr, TensorView y_ptr, TensorView x_ld, TensorView w_ld,
                         TensorView y_ld, TensorView empty_x_data, bool weight_column_major);

 TVM_FFI_DLL_EXPORT_TYPED_FUNC(cutlass_segment_gemm, CutlassSegmentGEMM);
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8, bmm_fp8);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8_get_algos, bmm_fp8_get_algos);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8_run_with_algo, bmm_fp8_run_with_algo);

csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h

Lines changed: 11 additions & 9 deletions
@@ -889,12 +889,14 @@ struct Gmem_tile_paged_kv {
     // Do not load/store if the thread is in the padded area
     col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_LDG;

-    int64_t kv_stride_in_bytes =
-        qkv_offset == 1 ? params.k_stride_in_bytes : params.v_stride_in_bytes;
-    // The head offset.
-    head_stride_in_bytes_ = (int64_t)(binfo.bidh / params.h_q_per_kv) * kv_stride_in_bytes;
-    // When V is padded (like MLA), we cannot use VALID_BYTES_PER_ROW
-    token_stride_in_bytes_ = kv_stride_in_bytes >> paged_kv_log2_block_size_;
+    // The head stride in bytes.
+    int64_t head_stride_in_bytes =
+        qkv_offset == 1 ? params.k_stride_in_bytes_2 : params.v_stride_in_bytes_2;
+    // The head offset in bytes.
+    head_offset_in_bytes_ = (binfo.bidh / params.h_q_per_kv) * head_stride_in_bytes;
+
+    // The token stride in bytes.
+    token_stride_in_bytes_ = qkv_offset == 1 ? params.k_stride_in_bytes : params.v_stride_in_bytes;

     // Take the CTA offset to modify the sequence length.
     // Actually we don't need that for flash attention.
@@ -918,7 +920,7 @@
     void const* ptrs[LDGS];

     // Offset for the new paged kv pointer.
-    uint64_t const head_col_in_bytes = head_stride_in_bytes_ + col_in_bytes_;
+    uint64_t const head_col_in_bytes = head_offset_in_bytes_ + col_in_bytes_;

     // Update paged_kv ptr for each LDG (reuse is possible).
 #pragma unroll
@@ -984,9 +986,9 @@
   int row_;
   int64_t col_in_bytes_;
   // Keep track of the head offset.
-  int64_t head_stride_in_bytes_;
+  int64_t head_offset_in_bytes_;
   // // for DeepSeek MLA, the stride of V tokens != VALID_BYTES_PER_ROW
-  int32_t token_stride_in_bytes_;
+  int64_t token_stride_in_bytes_;
   // The sequence length.
   int actual_seqlen_;
   // The past sequence length (kv_seqlen - q_seqlen) considering chunked context.

csrc/fmha_v2/fmha/warpspec/dma.h

Lines changed: 5 additions & 4 deletions
@@ -795,13 +795,14 @@ struct DMA {
     uint32_t tensor_size_v[4] = {dv, tokens_per_block, h_kv, INT_MAX};

     uint64_t tensor_stride_k[3];
-    tensor_stride_k[0] = params.k_stride_in_bytes / tokens_per_block;  // d
-    tensor_stride_k[1] = params.k_stride_in_bytes;                     // d * 64
+    tensor_stride_k[0] = params.k_stride_in_bytes;
+    tensor_stride_k[1] = params.k_stride_in_bytes_2;
     tensor_stride_k[2] = params.paged_kv_cache.mBytesPerBlock;
     uint64_t tensor_stride_v[3];
     // we cannot use dv * Kernel_traits::ELEMENT_BYTES because V may be padded (MLA)
-    tensor_stride_v[0] = params.v_stride_in_bytes / tokens_per_block;  // dv
-    tensor_stride_v[1] = params.v_stride_in_bytes;                     // dv * 64
+    // use the values given by caller
+    tensor_stride_v[0] = params.v_stride_in_bytes;
+    tensor_stride_v[1] = params.v_stride_in_bytes_2;
     tensor_stride_v[2] = params.paged_kv_cache.mBytesPerBlock;

     char* kv_ptr = reinterpret_cast<char*>(params.paged_kv_cache.mPoolPtr);

csrc/fmha_v2/fused_multihead_attention.h

Lines changed: 7 additions & 0 deletions
@@ -237,6 +237,13 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba
   int64_t q_stride_in_bytes;
   int64_t k_stride_in_bytes;
   int64_t v_stride_in_bytes;
+  // Paged KV uses 4D tensor, the tensor size is:
+  // HND = [num_pages, H, page_size, D] or NHD = [num_pages, page_size, H, D]
+  // so need another pair of stride.
+  // x_stride_in_bytes means the stride of tensor_size[1]
+  // x_stride_in_bytes_2 means the stride of tensor_size[2]
+  int64_t k_stride_in_bytes_2;
+  int64_t v_stride_in_bytes_2;

   // Paged KV load.
   int blocks_per_tma_load;
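The comment above is the key contract: x_stride_in_bytes is the within-page token stride (the tensor_size[1] dimension) and x_stride_in_bytes_2 is the head stride (tensor_size[2]), so a single 4D addressing formula serves both page layouts. A worked check under assumed sizes (fp16 KV, page_size = 64, h_kv = 8, d = 128; the numbers are illustrative, the formulas mirror set_params in csrc/fmha_v2_run.cu further down in this commit):

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t elem = 2, page_size = 64, h_kv = 8, d = 128;  // fp16 KV, assumed sizes

  // HND page layout [H_kv, page_size, D]: heads outermost within a page.
  const int64_t hnd_token = d * elem;              // k_stride_in_bytes   = 256 B
  const int64_t hnd_head  = page_size * d * elem;  // k_stride_in_bytes_2 = 16384 B

  // NHD page layout [page_size, H_kv, D]: tokens outermost within a page.
  const int64_t nhd_token = h_kv * d * elem;       // k_stride_in_bytes   = 2048 B
  const int64_t nhd_head  = d * elem;              // k_stride_in_bytes_2 = 256 B

  // Shared addressing within a page: offset(h, t, i) = h*stride_2 + t*stride_1 + i*elem.
  auto offset = [&](int64_t h, int64_t t, int64_t i, int64_t s1, int64_t s2) {
    return h * s2 + t * s1 + i * elem;
  };
  // Element (h=1, t=2, i=3) checked against direct row-major indexing of each layout.
  assert(offset(1, 2, 3, hnd_token, hnd_head) == ((1 * page_size + 2) * d + 3) * elem);
  assert(offset(1, 2, 3, nhd_token, nhd_head) == ((2 * h_kv + 1) * d + 3) * elem);
  std::printf("HND/NHD stride pairs address the same elements\n");
  return 0;
}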

csrc/fmha_v2/fused_multihead_attention_demo_bert_params.h

Lines changed: 8 additions & 0 deletions
@@ -177,4 +177,12 @@ struct Fused_multihead_attention_params_v2 {
   uint32_t* skip_softmax_total_blocks;
   uint32_t* skip_softmax_skipped_blocks;
 #endif
+
+  // Paged KV uses 4D tensor, the tensor size is:
+  // HND = [num_pages, H, page_size, D] or NHD = [num_pages, page_size, H, D]
+  // so need another pair of stride.
+  // x_stride_in_bytes means the stride of tensor_size[1]
+  // x_stride_in_bytes_2 means the stride of tensor_size[2]
+  int64_t k_stride_in_bytes_2;
+  int64_t v_stride_in_bytes_2;
 };

csrc/fmha_v2_run.cu

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ static inline void set_params(
5050
// types
5151
Data_type data_type, Data_type acc_type, Data_type output_dtype,
5252
// attention input layout
53-
Attention_input_layout input_layout,
53+
Attention_input_layout input_layout, const bool is_paged_hnd,
5454
// sizes
5555
const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t h_kv,
5656
const size_t d, const size_t dv, const size_t total, const size_t num_grouped_heads,
@@ -119,8 +119,21 @@ static inline void set_params(
119119
get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type),
120120
paged_kv_pool_ptr);
121121
params.paged_kv_cache.mBlockOffsets = paged_block_offsets;
122-
params.k_stride_in_bytes = get_size_in_bytes(tokens_per_block * d, data_type);
123-
params.v_stride_in_bytes = get_size_in_bytes(tokens_per_block * dv, data_type);
122+
// FMHA kernels always access the K/V tensor in 4D coordinate [num_pages, H_kv, page_size, D].
123+
// The layout of HND or NHD is implemented by tensor strides to get the correct memory
124+
// address. 4D tensor strides of HND: [block_size, page_size * D, D ,1] 4D tensor strides of
125+
// NHD: [block_size, D, H_kv * D, 1]
126+
if (is_paged_hnd) {
127+
params.k_stride_in_bytes = get_size_in_bytes(d, data_type);
128+
params.v_stride_in_bytes = get_size_in_bytes(dv, data_type);
129+
params.k_stride_in_bytes_2 = get_size_in_bytes(tokens_per_block * d, data_type);
130+
params.v_stride_in_bytes_2 = get_size_in_bytes(tokens_per_block * dv, data_type);
131+
} else {
132+
params.k_stride_in_bytes = get_size_in_bytes(h_kv * d, data_type);
133+
params.v_stride_in_bytes = get_size_in_bytes(h_kv * dv, data_type);
134+
params.k_stride_in_bytes_2 = get_size_in_bytes(d, data_type);
135+
params.v_stride_in_bytes_2 = get_size_in_bytes(dv, data_type);
136+
}
124137
} else if (input_layout == Attention_input_layout::SEPARATE_Q_K_V) {
125138
// Layout [B, S, H_kv, D].
126139
params.k_ptr = k_d;
@@ -247,10 +260,15 @@ static inline void determine_launch_params(
247260
launch_params.multi_processor_count = props.multiProcessorCount;
248261
launch_params.device_l2_cache_size = props.l2CacheSize;
249262

263+
#if 0
250264
// threshold for adopting flash attention or warp_specialized kernels.
251265
launch_params.flash_attention =
252266
(data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3) &&
253267
(s >= 16 && d >= 16) && !force_non_flash_attention;
268+
#else
269+
// Currently only flash attention kernels are generated in FlashInfer
270+
launch_params.flash_attention = true;
271+
#endif
254272

255273
// enable warp_speialized kernels when s >= 512 on hopper
256274
// note that warp_speialized kernels need flash attention + tma
@@ -304,11 +322,18 @@ static inline Attention_mask_type string_to_mask_type(const std::string& s) {
304322
return Attention_mask_type::CAUSAL; // default
305323
}
306324

307-
static inline Attention_input_layout string_to_input_layout(const std::string& s) {
325+
static inline Attention_input_layout string_to_input_layout(const std::string& s,
326+
bool& is_paged_hnd) {
327+
is_paged_hnd = false;
308328
if (s == "packed_qkv") return Attention_input_layout::PACKED_QKV;
309329
if (s == "contiguous_q_kv") return Attention_input_layout::CONTIGUOUS_Q_KV;
310-
if (s == "q_paged_kv_nhd") return Attention_input_layout::Q_PAGED_KV;
311-
if (s == "q_paged_kv_hnd") return Attention_input_layout::Q_PAGED_KV;
330+
if (s == "q_paged_kv_nhd") {
331+
return Attention_input_layout::Q_PAGED_KV;
332+
}
333+
if (s == "q_paged_kv_hnd") {
334+
is_paged_hnd = true;
335+
return Attention_input_layout::Q_PAGED_KV;
336+
}
312337
if (s == "separate_q_k_v") return Attention_input_layout::SEPARATE_Q_K_V;
313338
throw std::invalid_argument("Unsupported input_layout: " + s);
314339
}
@@ -330,7 +355,8 @@ void fmha_v2_run(
330355
float skip_softmax_threshold_scale_factor,
331356
Optional<ffi::TensorView> softmax_stats, // Optional [batch, s_q, num_heads, 2] for (max, sum)
332357
Optional<ffi::TensorView> sinks) {
333-
Attention_input_layout input_layout = string_to_input_layout(input_layout_str);
358+
bool is_paged_hnd;
359+
Attention_input_layout input_layout = string_to_input_layout(input_layout_str, is_paged_hnd);
334360
Attention_mask_type attention_mask_type = string_to_mask_type(mask_mode_str);
335361
Data_type output_dtype = dltype_to_data_type(o.dtype());
336362
// Get device properties
@@ -360,9 +386,12 @@ void fmha_v2_run(
360386
d = q.shape()[3]; // head_dim_qk
361387
dv = q.shape()[3]; // head_dim_v (same as d for standard attention)
362388
} else if (input_layout == Attention_input_layout::Q_PAGED_KV) {
363-
// q is 3D: [total_tokens, H, D], k/v are 4D paged: [num_pages, H_kv, page_size, D]
389+
// q is 3D: [total_tokens, H, D]
364390
h = q.shape()[1];
365-
h_kv = k.shape()[1];
391+
// k/v are 4D paged:
392+
// HND: [num_pages, H_kv, page_size, D]
393+
// NHD: [num_pages, page_size, H_kv, D]
394+
h_kv = k.shape()[is_paged_hnd ? 1 : 2];
366395
d = q.shape()[2];
367396
dv = v.shape()[3];
368397
} else if (input_layout == Attention_input_layout::CONTIGUOUS_Q_KV) {

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Lines changed: 3 additions & 0 deletions
@@ -71,6 +71,9 @@ using namespace tensorrt_llm::kernels;
 using namespace tensorrt_llm::common;

 namespace tensorrt_llm::kernels::cutlass_kernels {
+
+constexpr int CVT_ELTS_PER_THREAD = 8;
+
 /**
  * Takes the input maps and prepares the expanded maps for min latency
  * @param num_active_experts_per_node: Number of active experts on current node
