@@ -44,20 +44,25 @@ enum class TllmPagedAttentionMode {
 
 class TllmGenFmhaRunnerCache {
  public:
-  using Key = std::tuple<Data_type, Data_type, Data_type>;
+  using Key = std::tuple<Data_type, Data_type, Data_type, Data_type, int, int, int, int>;
 
-  static std::shared_ptr<TllmGenFmhaRunner> get(Data_type q_data_type, Data_type kv_data_type,
-                                                Data_type o_data_type) {
+  static std::shared_ptr<TllmGenFmhaRunner> get(Data_type q_data_type, Data_type k_data_type,
+                                                Data_type v_data_type, Data_type o_data_type,
+                                                int num_elts_sage_q = 0, int num_elts_sage_k = 0,
+                                                int num_elts_sage_p = 0, int num_elts_sage_v = 0) {
     static std::unordered_map<Key, std::shared_ptr<TllmGenFmhaRunner>, KeyHash> cache;
     static std::mutex cache_mutex;
-    Key key = std::make_tuple(q_data_type, kv_data_type, o_data_type);
+    Key key = std::make_tuple(q_data_type, k_data_type, v_data_type, o_data_type, num_elts_sage_q,
+                              num_elts_sage_k, num_elts_sage_p, num_elts_sage_v);
 
     std::lock_guard<std::mutex> lock(cache_mutex);
     auto it = cache.find(key);
     if (it != cache.end()) {
       return it->second;
     } else {
-      auto runner = std::make_shared<TllmGenFmhaRunner>(q_data_type, kv_data_type, o_data_type);
+      auto runner = std::make_shared<TllmGenFmhaRunner>(
+          q_data_type, k_data_type, v_data_type, o_data_type, num_elts_sage_q, num_elts_sage_k,
+          num_elts_sage_p, num_elts_sage_v);
       cache.emplace(key, runner);
       return runner;
     }
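A minimal caller-side sketch of the widened cache interface. The dtype enumerators and block sizes below are illustrative values, not taken from this diff; the zero-defaulted trailing arguments are what keep existing call sites compiling.

// Hypothetical call sites: two runners that differ only in SageAttention
// block size now map to distinct cache entries.
auto sage_runner = TllmGenFmhaRunnerCache::get(
    Data_type::DATA_TYPE_E4M3, Data_type::DATA_TYPE_E4M3, Data_type::DATA_TYPE_E4M3,
    Data_type::DATA_TYPE_BF16,
    /*num_elts_sage_q=*/64, /*num_elts_sage_k=*/64,
    /*num_elts_sage_p=*/0, /*num_elts_sage_v=*/64);
// Non-Sage callers rely on the zero defaults, matching the old keying
// behaviour apart from K and V now being keyed separately.
auto plain_runner = TllmGenFmhaRunnerCache::get(
    Data_type::DATA_TYPE_BF16, Data_type::DATA_TYPE_BF16, Data_type::DATA_TYPE_BF16,
    Data_type::DATA_TYPE_BF16);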
@@ -68,7 +73,10 @@ class TllmGenFmhaRunnerCache {
     std::size_t operator()(const Key& k) const {
       return std::hash<int>()(static_cast<int>(std::get<0>(k))) ^
              (std::hash<int>()(static_cast<int>(std::get<1>(k))) << 1) ^
-             (std::hash<int>()(static_cast<int>(std::get<2>(k))) << 2);
+             (std::hash<int>()(static_cast<int>(std::get<2>(k))) << 2) ^
+             (std::hash<int>()(static_cast<int>(std::get<3>(k))) << 3) ^
+             (std::hash<int>()(std::get<4>(k)) << 4) ^ (std::hash<int>()(std::get<5>(k)) << 5) ^
+             (std::hash<int>()(std::get<6>(k)) << 6) ^ (std::hash<int>()(std::get<7>(k)) << 7);
     }
   };
 };
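The KeyHash extension keeps the file's existing shift-XOR combining scheme: each field's std::hash is shifted left by its tuple position before the XOR, so equal values sitting in different positions do not cancel each other out the way a plain XOR of identical hashes would. A self-contained illustration of the pattern (the function name here is ours, not from the diff):

#include <cstddef>
#include <functional>
#include <tuple>

// Same combining scheme as KeyHash, reduced to three ints: hash each
// field, shift by its tuple position, then XOR the results together.
std::size_t shift_xor_hash(const std::tuple<int, int, int>& t) {
  return std::hash<int>()(std::get<0>(t)) ^
         (std::hash<int>()(std::get<1>(t)) << 1) ^
         (std::hash<int>()(std::get<2>(t)) << 2);
}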
@@ -96,7 +104,9 @@ void trtllm_paged_attention_launcher(
     FLASHINFER_ERROR(err_msg.str());
   }
 
-  auto fmha_runner = TllmGenFmhaRunnerCache::get(q_data_type, kv_data_type, o_data_type);
+  // For paged attention, K and V have the same dtype (kv_data_type).
+  auto fmha_runner =
+      TllmGenFmhaRunnerCache::get(q_data_type, kv_data_type, kv_data_type, o_data_type);
   TllmGenFmhaRunnerParams runner_params;
 
   // Common params
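Because the paged path never sets SageAttention block sizes, the four trailing arguments fall back to their zero defaults, and the call above is equivalent to spelling them out:

// Equivalent to the call in the hunk above, assuming the defaulted
// signature introduced earlier in this diff.
auto fmha_runner = TllmGenFmhaRunnerCache::get(
    q_data_type, kv_data_type, kv_data_type, o_data_type,
    /*num_elts_sage_q=*/0, /*num_elts_sage_k=*/0,
    /*num_elts_sage_p=*/0, /*num_elts_sage_v=*/0);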
@@ -225,6 +235,8 @@ inline Data_type dl_dtype_to_tllm_data_type(const DLDataType dtype) {
     return Data_type::DATA_TYPE_E4M3;
   } else if (dtype == dl_float8_e5m2) {
     return Data_type::DATA_TYPE_E5M2;
+  } else if (dtype == dl_int8) {
+    return Data_type::DATA_TYPE_INT8;
   } else if (dtype == dl_uint8) {
     // fp4 tensors are not supported in torch, so uint8_t is used as the container.
     return Data_type::DATA_TYPE_E2M1;
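For reference, DLPack encodes a dtype as {type code, bit width, lanes}, so a plain signed 8-bit tensor is {kDLInt, 8, 1}; assuming dl_int8 is defined that way elsewhere in this file, the new branch is exercised like this:

// Hypothetical caller of the mapping; kDLInt is the standard DLPack
// type code for signed integers.
DLDataType int8_dtype{kDLInt, 8, 1};
Data_type t = dl_dtype_to_tllm_data_type(int8_dtype);  // DATA_TYPE_INT8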
@@ -493,22 +505,27 @@ void trtllm_paged_attention_context(
 void trtllm_ragged_attention_launcher(
     void* out, void* query, void* key, void* value, void* workspace_buffer, int* seq_lens,
     int* cum_seq_lens_q, int* cum_seq_lens_kv, float* attention_sinks, float* lse,
-    Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type, int64_t max_q_len,
-    int64_t max_kv_len, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk,
-    int64_t head_dim_v, int64_t sum_seq_q, int64_t sum_seq_kv, double bmm1_scale, double bmm2_scale,
-    const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale,
-    int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl, bool is_causal,
-    int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch,
-    int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch,
-    float skip_softmax_threshold_scale_factor, bool skips_softmax, int64_t workspace_size,
+    Data_type q_data_type, Data_type k_data_type, Data_type v_data_type, Data_type o_data_type,
+    int64_t max_q_len, int64_t max_kv_len, int64_t num_qo_heads, int64_t num_kv_heads,
+    int64_t head_dim_qk, int64_t head_dim_v, int64_t sum_seq_q, int64_t sum_seq_kv,
+    double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr,
+    const float* bmm2_scale_ptr, double o_sf_scale, int64_t batch_size, int64_t window_left,
+    int64_t sm_count, bool enable_pdl, bool is_causal, int64_t k_stride_keys_values,
+    int64_t k_stride_heads, int64_t k_stride_batch, int64_t v_stride_keys_values,
+    int64_t v_stride_heads, int64_t v_stride_batch, float skip_softmax_threshold_scale_factor,
+    bool skips_softmax, int64_t workspace_size, const float* sage_attn_sfs_q,
+    const float* sage_attn_sfs_k, const float* sage_attn_sfs_p, const float* sage_attn_sfs_v,
+    int num_elts_sage_q, int num_elts_sage_k, int num_elts_sage_p, int num_elts_sage_v,
     cudaStream_t stream) {
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
     err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
             << " and num_qo_heads: " << num_qo_heads;
     FLASHINFER_ERROR(err_msg.str());
   }
-  auto fmha_runner = TllmGenFmhaRunnerCache::get(q_data_type, kv_data_type, o_data_type);
+  auto fmha_runner = TllmGenFmhaRunnerCache::get(q_data_type, k_data_type, v_data_type, o_data_type,
+                                                 num_elts_sage_q, num_elts_sage_k, num_elts_sage_p,
+                                                 num_elts_sage_v);
   TllmGenFmhaRunnerParams runner_params;
 
   runner_params.qPtr = query;
@@ -576,6 +593,12 @@ void trtllm_ragged_attention_launcher(
   runner_params.mSkipsSoftmaxWhenPossible = skips_softmax;
   runner_params.mSkipSoftmaxThresholdScaleFactor = skip_softmax_threshold_scale_factor;
 
+  // SageAttention scaling factors.
+  runner_params.ptrSageAttnSfsQ = sage_attn_sfs_q;
+  runner_params.ptrSageAttnSfsK = sage_attn_sfs_k;
+  runner_params.ptrSageAttnSfsP = sage_attn_sfs_p;
+  runner_params.ptrSageAttnSfsV = sage_attn_sfs_v;
+
   auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params);
   if (!foundKernels) {
     std::ostringstream err_msg;
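The four pointers forwarded above carry per-block SageAttention scale factors for Q, K, the softmax probabilities P, and V, with num_elts_sage_* giving how many elements each scale factor covers. The exact buffer layout is not spelled out in this diff; under the assumption that scale factors tile the sequence dimension, a caller would size them roughly like this:

#include <cstdint>

// Hypothetical sizing helper: one scale factor per block of
// num_elts_per_blk elements, rounded up; a zero block size means the
// tensor is not Sage-quantized and no scale factors are needed.
inline int64_t num_sage_scale_factors(int64_t seq_len, int64_t num_elts_per_blk) {
  return num_elts_per_blk > 0 ? (seq_len + num_elts_per_blk - 1) / num_elts_per_blk : 0;
}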
@@ -586,16 +609,18 @@ void trtllm_ragged_attention_launcher(
   fmha_runner->run(runner_params);
 }
 
-void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, TensorView value,
-                             TensorView workspace_buffer, TensorView seq_lens, int64_t max_q_len,
-                             int64_t max_kv_len, Variant<double, ffi::Tensor> bmm1_scale,
-                             Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale,
-                             int64_t batch_size, int64_t window_left, TensorView cum_seq_lens_q,
-                             TensorView cum_seq_lens_kv, int64_t sm_count, bool enable_pdl,
-                             bool is_causal, int64_t workspace_size,
-                             Optional<TensorView> attention_sinks,
-                             Optional<float> skip_softmax_threshold_scale_factor,
-                             Optional<TensorView> lse) {
+void trtllm_ragged_attention(
+    TensorView out, TensorView query, TensorView key, TensorView value, TensorView workspace_buffer,
+    TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
+    Variant<double, ffi::Tensor> bmm1_scale, Variant<double, ffi::Tensor> bmm2_scale,
+    double o_sf_scale, int64_t batch_size, int64_t window_left, TensorView cum_seq_lens_q,
+    TensorView cum_seq_lens_kv, int64_t sm_count, bool enable_pdl, bool is_causal,
+    int64_t workspace_size, Optional<TensorView> attention_sinks,
+    Optional<float> skip_softmax_threshold_scale_factor, Optional<TensorView> lse,
+    Optional<TensorView> sage_attn_sfs_q, Optional<TensorView> sage_attn_sfs_k,
+    Optional<TensorView> sage_attn_sfs_p, Optional<TensorView> sage_attn_sfs_v,
+    int64_t num_elts_per_sage_attn_blk_q, int64_t num_elts_per_sage_attn_blk_k,
+    int64_t num_elts_per_sage_attn_blk_p, int64_t num_elts_per_sage_attn_blk_v) {
   float* attention_sinks_ptr = nullptr;
   if (attention_sinks.has_value()) {
     TVM_FFI_ICHECK_EQ(attention_sinks.value().dtype(), dl_float32)
@@ -613,7 +638,8 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
   TVM_FFI_ICHECK_EQ(value.ndim(), 3) << "value must be a 3D tensor";
 
   auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
-  auto kv_data_type = dl_dtype_to_tllm_data_type(key.dtype());
+  auto k_data_type = dl_dtype_to_tllm_data_type(key.dtype());
+  auto v_data_type = dl_dtype_to_tllm_data_type(value.dtype());
   auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
   const auto stream = get_stream(query.device());
   int num_qo_heads = query.size(1);
@@ -629,6 +655,20 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
   int v_stride_heads = value.stride(1);
   int v_stride_batch = value.numel();
 
+  // SageAttention scaling factor pointers.
+  const float* sage_attn_sfs_q_ptr =
+      sage_attn_sfs_q.has_value() ? static_cast<const float*>(sage_attn_sfs_q.value().data_ptr())
+                                  : nullptr;
+  const float* sage_attn_sfs_k_ptr =
+      sage_attn_sfs_k.has_value() ? static_cast<const float*>(sage_attn_sfs_k.value().data_ptr())
+                                  : nullptr;
+  const float* sage_attn_sfs_p_ptr =
+      sage_attn_sfs_p.has_value() ? static_cast<const float*>(sage_attn_sfs_p.value().data_ptr())
+                                  : nullptr;
+  const float* sage_attn_sfs_v_ptr =
+      sage_attn_sfs_v.has_value() ? static_cast<const float*>(sage_attn_sfs_v.value().data_ptr())
+                                  : nullptr;
+
   auto maybe_bmm1_scale_value = bmm1_scale.as<double>();
   auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
   auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as<ffi::Tensor>();
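The four Optional-to-pointer extractions above repeat one pattern; a hypothetical helper (not part of this diff) would collapse them:

// Sketch only: returns the float data pointer if the optional tensor is
// present, else nullptr, matching the ternaries above.
inline const float* optional_f32_ptr(const Optional<TensorView>& t) {
  return t.has_value() ? static_cast<const float*>(t.value().data_ptr()) : nullptr;
}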
@@ -658,12 +698,17 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
       out.data_ptr(), query.data_ptr(), key.data_ptr(), value.data_ptr(),
       workspace_buffer.data_ptr(), static_cast<int*>(seq_lens.data_ptr()),
       static_cast<int*>(cum_seq_lens_q.data_ptr()), static_cast<int*>(cum_seq_lens_kv.data_ptr()),
-      attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type, o_data_type, max_q_len, max_kv_len,
-      num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale_value,
-      bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, batch_size, window_left,
-      sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch,
-      v_stride_keys_values, v_stride_heads, v_stride_batch,
-      skip_softmax_threshold_scale_factor_value, skips_softmax, workspace_size, stream);
+      attention_sinks_ptr, lse_ptr, q_data_type, k_data_type, v_data_type, o_data_type, max_q_len,
+      max_kv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv,
+      bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale,
+      batch_size, window_left, sm_count, enable_pdl, is_causal, k_stride_keys_values,
+      k_stride_heads, k_stride_batch, v_stride_keys_values, v_stride_heads, v_stride_batch,
+      skip_softmax_threshold_scale_factor_value, skips_softmax, workspace_size, sage_attn_sfs_q_ptr,
+      sage_attn_sfs_k_ptr, sage_attn_sfs_p_ptr, sage_attn_sfs_v_ptr,
+      static_cast<int>(num_elts_per_sage_attn_blk_q),
+      static_cast<int>(num_elts_per_sage_attn_blk_k),
+      static_cast<int>(num_elts_per_sage_attn_blk_p),
+      static_cast<int>(num_elts_per_sage_attn_blk_v), stream);
 }
 
 namespace trtllm_cubin_loader {