Commit 789ece9

Integration of SageAttn & QkBf16PvFp8 Attn
Signed-off-by: Ruqing Xu <7891482+xrq-phys@users.noreply.github.com>
1 parent a98623e commit 789ece9

8 files changed

Lines changed: 637 additions & 78 deletions

csrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 78 additions & 33 deletions
@@ -44,20 +44,25 @@ enum class TllmPagedAttentionMode {
 
 class TllmGenFmhaRunnerCache {
  public:
-  using Key = std::tuple<Data_type, Data_type, Data_type>;
+  using Key = std::tuple<Data_type, Data_type, Data_type, Data_type, int, int, int, int>;
 
-  static std::shared_ptr<TllmGenFmhaRunner> get(Data_type q_data_type, Data_type kv_data_type,
-                                                Data_type o_data_type) {
+  static std::shared_ptr<TllmGenFmhaRunner> get(Data_type q_data_type, Data_type k_data_type,
+                                                Data_type v_data_type, Data_type o_data_type,
+                                                int num_elts_sage_q = 0, int num_elts_sage_k = 0,
+                                                int num_elts_sage_p = 0, int num_elts_sage_v = 0) {
     static std::unordered_map<Key, std::shared_ptr<TllmGenFmhaRunner>, KeyHash> cache;
     static std::mutex cache_mutex;
-    Key key = std::make_tuple(q_data_type, kv_data_type, o_data_type);
+    Key key = std::make_tuple(q_data_type, k_data_type, v_data_type, o_data_type, num_elts_sage_q,
+                              num_elts_sage_k, num_elts_sage_p, num_elts_sage_v);
 
     std::lock_guard<std::mutex> lock(cache_mutex);
     auto it = cache.find(key);
     if (it != cache.end()) {
       return it->second;
     } else {
-      auto runner = std::make_shared<TllmGenFmhaRunner>(q_data_type, kv_data_type, o_data_type);
+      auto runner = std::make_shared<TllmGenFmhaRunner>(
+          q_data_type, k_data_type, v_data_type, o_data_type, num_elts_sage_q, num_elts_sage_k,
+          num_elts_sage_p, num_elts_sage_v);
       cache.emplace(key, runner);
       return runner;
     }
@@ -68,7 +73,10 @@ class TllmGenFmhaRunnerCache {
     std::size_t operator()(const Key& k) const {
       return std::hash<int>()(static_cast<int>(std::get<0>(k))) ^
              (std::hash<int>()(static_cast<int>(std::get<1>(k))) << 1) ^
-             (std::hash<int>()(static_cast<int>(std::get<2>(k))) << 2);
+             (std::hash<int>()(static_cast<int>(std::get<2>(k))) << 2) ^
+             (std::hash<int>()(static_cast<int>(std::get<3>(k))) << 3) ^
+             (std::hash<int>()(std::get<4>(k)) << 4) ^ (std::hash<int>()(std::get<5>(k)) << 5) ^
+             (std::hash<int>()(std::get<6>(k)) << 6) ^ (std::hash<int>()(std::get<7>(k)) << 7);
     }
   };
 };
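Note: the cache key grows from a (q, kv, o) dtype triple to an eight-element tuple so that runners built for different K/V dtypes or different SageAttention block sizes never collide. A minimal Python sketch of the same keying idea, for illustration only (the names below are hypothetical, not part of the FlashInfer API):

_runner_cache = {}

def get_runner(q_dtype, k_dtype, v_dtype, o_dtype,
               blk_q=0, blk_k=0, blk_p=0, blk_v=0):
    # Mirror of TllmGenFmhaRunnerCache::get above: the per-tensor dtypes and
    # the four SageAttention block sizes all participate in the lookup key.
    key = (q_dtype, k_dtype, v_dtype, o_dtype, blk_q, blk_k, blk_p, blk_v)
    if key not in _runner_cache:
        _runner_cache[key] = object()  # stand-in for constructing a runner
    return _runner_cache[key]
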
@@ -96,7 +104,9 @@ void trtllm_paged_attention_launcher(
     FLASHINFER_ERROR(err_msg.str());
   }
 
-  auto fmha_runner = TllmGenFmhaRunnerCache::get(q_data_type, kv_data_type, o_data_type);
+  // For paged attention, K and V have the same dtype (kv_data_type).
+  auto fmha_runner =
+      TllmGenFmhaRunnerCache::get(q_data_type, kv_data_type, kv_data_type, o_data_type);
   TllmGenFmhaRunnerParams runner_params;
 
   // Common params
@@ -225,6 +235,8 @@ inline Data_type dl_dtype_to_tllm_data_type(const DLDataType dtype) {
     return Data_type::DATA_TYPE_E4M3;
   } else if (dtype == dl_float8_e5m2) {
     return Data_type::DATA_TYPE_E5M2;
+  } else if (dtype == dl_int8) {
+    return Data_type::DATA_TYPE_INT8;
   } else if (dtype == dl_uint8) {
     // fp4 tensor is not supported in torch and use uint8_t as container.
     return Data_type::DATA_TYPE_E2M1;
@@ -493,22 +505,27 @@ void trtllm_paged_attention_context(
 void trtllm_ragged_attention_launcher(
     void* out, void* query, void* key, void* value, void* workspace_buffer, int* seq_lens,
     int* cum_seq_lens_q, int* cum_seq_lens_kv, float* attention_sinks, float* lse,
-    Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type, int64_t max_q_len,
-    int64_t max_kv_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk,
-    int64_t head_dim_v, int64_t sum_seq_q, int64_t sum_seq_kv, double bmm1_scale, double bmm2_scale,
-    const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale,
-    int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl, bool is_causal,
-    int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch,
-    int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch,
-    float skip_softmax_threshold_scale_factor, bool skips_softmax, int64_t workspace_size,
+    Data_type q_data_type, Data_type k_data_type, Data_type v_data_type, Data_type o_data_type,
+    int64_t max_q_len, int64_t max_kv_len, int64_t num_qo_heads, int64_t num_kv_heads,
+    int64_t head_dim_qk, int64_t head_dim_v, int64_t sum_seq_q, int64_t sum_seq_kv,
+    double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr,
+    const float* bmm2_scale_ptr, double o_sf_scale, int64_t batch_size, int64_t window_left,
+    int64_t sm_count, bool enable_pdl, bool is_causal, int64_t k_stride_keys_values,
+    int64_t k_stride_heads, int64_t k_stride_batch, int64_t v_stride_keys_values,
+    int64_t v_stride_heads, int64_t v_stride_batch, float skip_softmax_threshold_scale_factor,
+    bool skips_softmax, int64_t workspace_size, const float* sage_attn_sfs_q,
+    const float* sage_attn_sfs_k, const float* sage_attn_sfs_p, const float* sage_attn_sfs_v,
+    int num_elts_sage_q, int num_elts_sage_k, int num_elts_sage_p, int num_elts_sage_v,
     cudaStream_t stream) {
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
     err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
             << " and num_qo_heads: " << num_qo_heads;
     FLASHINFER_ERROR(err_msg.str());
   }
-  auto fmha_runner = TllmGenFmhaRunnerCache::get(q_data_type, kv_data_type, o_data_type);
+  auto fmha_runner = TllmGenFmhaRunnerCache::get(q_data_type, k_data_type, v_data_type, o_data_type,
+                                                 num_elts_sage_q, num_elts_sage_k, num_elts_sage_p,
+                                                 num_elts_sage_v);
   TllmGenFmhaRunnerParams runner_params;
 
   runner_params.qPtr = query;
@@ -576,6 +593,12 @@ void trtllm_ragged_attention_launcher(
   runner_params.mSkipsSoftmaxWhenPossible = skips_softmax;
   runner_params.mSkipSoftmaxThresholdScaleFactor = skip_softmax_threshold_scale_factor;
 
+  // SageAttention scaling factors.
+  runner_params.ptrSageAttnSfsQ = sage_attn_sfs_q;
+  runner_params.ptrSageAttnSfsK = sage_attn_sfs_k;
+  runner_params.ptrSageAttnSfsP = sage_attn_sfs_p;
+  runner_params.ptrSageAttnSfsV = sage_attn_sfs_v;
+
   auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params);
   if (!foundKernels) {
     std::ostringstream err_msg;
@@ -586,16 +609,18 @@ void trtllm_ragged_attention_launcher(
   fmha_runner->run(runner_params);
 }
 
-void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, TensorView value,
-                             TensorView workspace_buffer, TensorView seq_lens, int64_t max_q_len,
-                             int64_t max_kv_len, Variant<double, ffi::Tensor> bmm1_scale,
-                             Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale,
-                             int64_t batch_size, int64_t window_left, TensorView cum_seq_lens_q,
-                             TensorView cum_seq_lens_kv, int64_t sm_count, bool enable_pdl,
-                             bool is_causal, int64_t workspace_size,
-                             Optional<TensorView> attention_sinks,
-                             Optional<float> skip_softmax_threshold_scale_factor,
-                             Optional<TensorView> lse) {
+void trtllm_ragged_attention(
+    TensorView out, TensorView query, TensorView key, TensorView value, TensorView workspace_buffer,
+    TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
+    Variant<double, ffi::Tensor> bmm1_scale, Variant<double, ffi::Tensor> bmm2_scale,
+    double o_sf_scale, int64_t batch_size, int64_t window_left, TensorView cum_seq_lens_q,
+    TensorView cum_seq_lens_kv, int64_t sm_count, bool enable_pdl, bool is_causal,
+    int64_t workspace_size, Optional<TensorView> attention_sinks,
+    Optional<float> skip_softmax_threshold_scale_factor, Optional<TensorView> lse,
+    Optional<TensorView> sage_attn_sfs_q, Optional<TensorView> sage_attn_sfs_k,
+    Optional<TensorView> sage_attn_sfs_p, Optional<TensorView> sage_attn_sfs_v,
+    int64_t num_elts_per_sage_attn_blk_q, int64_t num_elts_per_sage_attn_blk_k,
+    int64_t num_elts_per_sage_attn_blk_p, int64_t num_elts_per_sage_attn_blk_v) {
   float* attention_sinks_ptr = nullptr;
   if (attention_sinks.has_value()) {
     TVM_FFI_ICHECK_EQ(attention_sinks.value().dtype(), dl_float32)
@@ -613,7 +638,8 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
   TVM_FFI_ICHECK_EQ(value.ndim(), 3) << "value must be a 3D tensor";
 
   auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
-  auto kv_data_type = dl_dtype_to_tllm_data_type(key.dtype());
+  auto k_data_type = dl_dtype_to_tllm_data_type(key.dtype());
+  auto v_data_type = dl_dtype_to_tllm_data_type(value.dtype());
   auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
   const auto stream = get_stream(query.device());
   int num_qo_heads = query.size(1);
@@ -629,6 +655,20 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
   int v_stride_heads = value.stride(1);
   int v_stride_batch = value.numel();
 
+  // SageAttention scaling factor pointers.
+  const float* sage_attn_sfs_q_ptr =
+      sage_attn_sfs_q.has_value() ? static_cast<const float*>(sage_attn_sfs_q.value().data_ptr())
+                                  : nullptr;
+  const float* sage_attn_sfs_k_ptr =
+      sage_attn_sfs_k.has_value() ? static_cast<const float*>(sage_attn_sfs_k.value().data_ptr())
+                                  : nullptr;
+  const float* sage_attn_sfs_p_ptr =
+      sage_attn_sfs_p.has_value() ? static_cast<const float*>(sage_attn_sfs_p.value().data_ptr())
+                                  : nullptr;
+  const float* sage_attn_sfs_v_ptr =
+      sage_attn_sfs_v.has_value() ? static_cast<const float*>(sage_attn_sfs_v.value().data_ptr())
+                                  : nullptr;
+
   auto maybe_bmm1_scale_value = bmm1_scale.as<double>();
   auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
   auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as<ffi::Tensor>();
@@ -658,12 +698,17 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
       out.data_ptr(), query.data_ptr(), key.data_ptr(), value.data_ptr(),
       workspace_buffer.data_ptr(), static_cast<int*>(seq_lens.data_ptr()),
       static_cast<int*>(cum_seq_lens_q.data_ptr()), static_cast<int*>(cum_seq_lens_kv.data_ptr()),
-      attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type, o_data_type, max_q_len, max_kv_len,
-      num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale_value,
-      bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, batch_size, window_left,
-      sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch,
-      v_stride_keys_values, v_stride_heads, v_stride_batch,
-      skip_softmax_threshold_scale_factor_value, skips_softmax, workspace_size, stream);
+      attention_sinks_ptr, lse_ptr, q_data_type, k_data_type, v_data_type, o_data_type, max_q_len,
+      max_kv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv,
+      bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
+      batch_size, window_left, sm_count, enable_pdl, is_causal, k_stride_keys_values,
+      k_stride_heads, k_stride_batch, v_stride_keys_values, v_stride_heads, v_stride_batch,
+      skip_softmax_threshold_scale_factor_value, skips_softmax, workspace_size, sage_attn_sfs_q_ptr,
+      sage_attn_sfs_k_ptr, sage_attn_sfs_p_ptr, sage_attn_sfs_v_ptr,
+      static_cast<int>(num_elts_per_sage_attn_blk_q),
+      static_cast<int>(num_elts_per_sage_attn_blk_k),
+      static_cast<int>(num_elts_per_sage_attn_blk_p),
+      static_cast<int>(num_elts_per_sage_attn_blk_v), stream);
 }
 
 namespace trtllm_cubin_loader {

flashinfer/artifacts.py

Lines changed: 2 additions & 2 deletions
@@ -135,7 +135,7 @@ class ArtifactPath:
     When compiling new cubins for backend directories, update the corresponding path.
     """
 
-    TRTLLM_GEN_FMHA: str = "82f4c77d9cf83e3fcf105feda4ce3445100ab491/fmha/trtllm-gen/"
+    TRTLLM_GEN_FMHA: str = "ce9168f3a3f60ffaccbaf6b2ee23642d8207b3b7/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
         "39a9d28268f43475a757d5700af135e1e58c9849/batched_gemm-5ee61af-2b9855b/"
     )
@@ -155,7 +155,7 @@ class CheckSumHash:
     """
 
     TRTLLM_GEN_FMHA: str = (
-        "56c95fbe5d1b5d0d9ded7706e1c0b7ebf0582d9cfd2f9382acd878b6b9d58c89"
+        "6a05a464e0101612a7598273bda641633b6db15abd76bd7f9a94a07646c7127c"
    )
     TRTLLM_GEN_BMM: str = (
         "db06db7f36a2a9395a2041ff6ac016fe664874074413a2ed90797f91ef17e0f6"

flashinfer/prefill.py

Lines changed: 18 additions & 1 deletion
@@ -3692,6 +3692,13 @@ def trtllm_ragged_attention_deepseek(
     skip_softmax_threshold_scale_factor: Optional[float] = None,
     out: Optional[torch.Tensor] = None,
     lse: Optional[torch.Tensor] = None,
+    sage_attn_sfs: Tuple[
+        Optional[torch.Tensor],
+        Optional[torch.Tensor],
+        Optional[torch.Tensor],
+        Optional[torch.Tensor],
+    ] = (None, None, None, None),
+    num_elts_per_sage_attn_blk: Tuple[int, int, int, int] = (0, 0, 0, 0),
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     Parameters
@@ -3766,7 +3773,7 @@ def trtllm_ragged_attention_deepseek(
     if out is None:
         # FP8 inputs produce bfloat16 output by default (TRT-LLM kernels
         # do not support FP8 output for ragged attention)
-        if query.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        if query.dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.int8):
            out_dtype = torch.bfloat16
         else:
             out_dtype = query.dtype
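Note: the hunk above extends the default-output-dtype rule so that int8 inputs (as used by SageAttention), like FP8 inputs, fall back to bfloat16 output. A standalone restatement of that rule as a small sketch, assuming only what the diff shows:

import torch

def default_out_dtype(q_dtype: torch.dtype) -> torch.dtype:
    # Low-precision query dtypes cannot be used as the output dtype of the
    # ragged-attention kernels, so bfloat16 is substituted by default.
    if q_dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.int8):
        return torch.bfloat16
    return q_dtype

assert default_out_dtype(torch.int8) == torch.bfloat16
assert default_out_dtype(torch.bfloat16) == torch.bfloat16
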
@@ -3806,6 +3813,8 @@
         assert bmm2_scale.dtype == torch.float32
 
     workspace_size = workspace_buffer.numel() * workspace_buffer.element_size()
+    sage_attn_sfs_q, sage_attn_sfs_k, sage_attn_sfs_p, sage_attn_sfs_v = sage_attn_sfs
+    num_elts_sage_q, num_elts_sage_k, num_elts_sage_p, num_elts_sage_v = num_elts_per_sage_attn_blk
     run_func(
         out,
         query,
@@ -3829,6 +3838,14 @@
         attention_sinks,
         skip_softmax_threshold_scale_factor,
         lse,
+        sage_attn_sfs_q,
+        sage_attn_sfs_k,
+        sage_attn_sfs_p,
+        sage_attn_sfs_v,
+        num_elts_sage_q,
+        num_elts_sage_k,
+        num_elts_sage_p,
+        num_elts_sage_v,
     )
     if return_lse:
         assert lse is not None, (
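Note: the new sage_attn_sfs tuple holds the optional float32 scale-factor tensors for Q, K, P, and V, and num_elts_per_sage_attn_blk holds the matching per-tensor SageAttention block sizes; the wrapper simply unpacks both tuples and forwards them to the kernel launcher. A hedged sketch of how per-block int8 data and float32 scale factors might be produced for one such tensor (illustrative only: the helper below is hypothetical, and the blocking axis and scale-factor layout the kernel actually expects are not specified by this commit):

import torch

def quantize_per_block_int8(x: torch.Tensor, blk: int):
    # Symmetric per-block int8 quantization over the last dimension: each
    # group of `blk` elements shares one float32 scale factor, so that
    # dequantization is x_int8.float() * scale per block.
    assert x.shape[-1] % blk == 0
    blocks = x.float().reshape(*x.shape[:-1], -1, blk)
    scales = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(blocks / scales), -127, 127).to(torch.int8)
    return q.reshape(x.shape), scales.squeeze(-1).contiguous()

# Example: 64-element blocks for Q and K, with P and V left unquantized.
q_int8, q_sfs = quantize_per_block_int8(torch.randn(2, 8, 128), blk=64)
k_int8, k_sfs = quantize_per_block_int8(torch.randn(2, 8, 128), blk=64)
sage_attn_sfs = (q_sfs, k_sfs, None, None)
num_elts_per_sage_attn_blk = (64, 64, 0, 0)
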
