@@ -2950,7 +2950,8 @@ struct MHAHelper {
                         const PlainTensor& block_indices,
                         const PlainTensor& block_indices_begins,
                         const PlainTensor& alibi_slopes,
-                        const PlainTensor& score_aggregation_window) {
+                        const PlainTensor& score_aggregation_window,
+                        const std::vector<PlainTensor>& sparse_attention_mask = {}) {
         auto B = past_lens.size(0);
         auto q_len = query.size(2);
         auto kv_len_in_blocks = div_up(max_context_len, _block_size);
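The indexing in the hunks below implies a block-granular mask: sparse_attention_mask[b] is a per-sequence boolean PlainTensor addressed as (head, query_block, kv_block), with token positions mapped to blocks by dividing by _block_size. A minimal standalone sketch of that layout follows; the BlockMask type is hypothetical and only models the addressing, the real kernel reads a PlainTensor.

// Hypothetical stand-in for the per-sequence mask; models the layout implied
// by sparse_attention_mask[b].ptr<bool>(h, q_blk, k_blk) in the diff.
#include <cstddef>
#include <cstdint>
#include <vector>

struct BlockMask {
    size_t heads, q_blocks, kv_blocks;
    std::vector<uint8_t> data;  // 1 = attend to this block, 0 = skip it

    BlockMask(size_t heads, size_t q_blocks, size_t kv_blocks)
        : heads(heads),
          q_blocks(q_blocks),
          kv_blocks(kv_blocks),
          data(heads * q_blocks * kv_blocks, 1) {}

    // Same addressing as ptr<bool>(h, q_blk, k_blk)[0]: row-major over
    // (head, query block, kv block).
    bool at(size_t h, size_t q_blk, size_t k_blk) const {
        return data[(h * q_blocks + q_blk) * kv_blocks + k_blk] != 0;
    }
};

// Token positions map to blocks exactly as in the kernel:
//   q_blk = pq / block_size;  k_blk = pk / block_size;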
@@ -2995,13 +2996,20 @@ struct MHAHelper {
 
             // kv_len must be valid
             auto pk = pk_in_blocks * _block_size;
+            size_t k_blk = pk / _block_size;
             if (pk < context_len) {
                 auto block_number = block_indices.ptr<int32_t>()[block_indices_begins.ptr<int32_t>()[b] + pk_in_blocks];
 #if defined(OPENVINO_ARCH_X86_64)
                 if (one_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) {
                     _gemv->tile_config();
                     for (size_t pq = 0; pq < q_len; pq++) {
+                        size_t q_blk = pq / _block_size;
                         for (size_t h = hq_beg; h < hq_end; h++) {
+                            // sparse attention mask check
+                            if (!sparse_attention_mask.empty() &&
+                                !sparse_attention_mask[b].ptr<bool>(h, q_blk, k_blk)[0]) {
+                                continue;
+                            }
                             (*_gemv)(
                                 query.ptr<DATA_TYPE>(b, h, pq),
                                 key_cache.ptr<typename ov::element_type_traits<KEY_PREC>::value_type>(block_number, hk),
@@ -3012,7 +3020,13 @@ struct MHAHelper {
                 } else {
 #endif
                     for (size_t pq = 0; pq < q_len; pq++) {
+                        size_t q_blk = pq / _block_size;
                         for (size_t h = hq_beg; h < hq_end; h++) {
+                            // sparse attention mask check
+                            if (!sparse_attention_mask.empty() &&
+                                !sparse_attention_mask[b].ptr<bool>(h, q_blk, k_blk)[0]) {
+                                continue;
+                            }
                             if constexpr (one_of(KEY_PREC, ov::element::u8, ov::element::u4)) {
                                 dot_product_block_quantized<DATA_TYPE, KEY_PREC>(
                                     query.ptr<DATA_TYPE>(b, h, pq),
@@ -3050,6 +3064,16 @@ struct MHAHelper {
                 alibi_slope = alibi_slopes.ptr<float>()[h];
                 alibi_lookup = _alibi_lookup.ptr<float>() + _alibi_lookup.m_dims[0] - cur_kv_len;
             }
+            // sparse attention mask: set positions where mask == false to -inf
+            if (!sparse_attention_mask.empty()) {
+                size_t q_blk = pq / _block_size;
+                for (size_t k = 0; k < cur_kv_len; ++k) {
+                    size_t k_blk = k / _block_size;
+                    if (!sparse_attention_mask[b].ptr<bool>(h, q_blk, k_blk)[0]) {
+                        _weight_bhl.ptr<float>(b, h, pq)[k] = -std::numeric_limits<float>::infinity();
+                    }
+                }
+            }
             attn_softmax_kernel<float>(_weight_bhl.ptr<float>(b, h, pq),
                                        _weight_bhl.ptr<float>(b, h, pq),
                                        _d_scale,
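Unlike the paths above, this one does not skip work per block; it overwrites masked logits with -inf so that the following softmax assigns them exactly zero weight. A tiny standalone check of that identity (plain C++, independent of the kernel):

// Minimal sketch: why writing -inf into a logit before softmax zeroes it out.
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
    float w[3] = {1.0f, -std::numeric_limits<float>::infinity(), 2.0f};
    // Standard max-subtraction softmax; -inf survives as exp(-inf) == 0.
    float m = std::fmax(w[0], std::fmax(w[1], w[2]));
    float e[3], sum = 0.0f;
    for (int i = 0; i < 3; ++i) {
        e[i] = std::exp(w[i] - m);
        sum += e[i];
    }
    for (int i = 0; i < 3; ++i) {
        std::printf("%f ", e[i] / sum);  // prints 0 for the masked slot
    }
    return 0;
}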
@@ -3105,8 +3129,14 @@ struct MHAHelper {
         // kv_len must be valid
         if (pv < context_len) {
             auto block_number = block_indices.ptr<int32_t>()[block_indices_begins.ptr<int32_t>()[b] + pv_in_blocks];
+            size_t k_blk = pv / _block_size;
             for (size_t pq = 0; pq < q_len; pq++) {
+                size_t q_blk = pq / _block_size;
                 for (size_t h = hq_beg; h < hq_end; h++) {
+                    // sparse attention mask check
+                    if (!sparse_attention_mask.empty() && !sparse_attention_mask[b].ptr<bool>(h, q_blk, k_blk)[0]) {
+                        continue;
+                    }
                     if constexpr (one_of(VALUE_PREC, ov::element::u8, ov::element::u4)) {
                         attn_acc_value_block_quantized<uint8_t, VALUE_PREC>(
                             _output_bhl.ptr<float>(b, pv_in_blocks, h, pq),
@@ -3558,46 +3588,48 @@ struct MHA {
 
         auto nthr = static_cast<size_t>(parallel_get_max_threads());
 
-        // if (past_lens.m_dims[0] >= nthr || _workitems.get_reorder_max_batch_size() > 0) {
-        //     exec_loop_mixed(query,
-        //                     present_key,
-        //                     present_value,
-        //                     output_emb,
-        //                     output_score,
-        //                     max_context_len,
-        //                     past_lens,
-        //                     subsequence_begins,
-        //                     block_indices,
-        //                     block_indices_begins,
-        //                     alibi_slopes,
-        //                     score_aggregation_window);
-        // } else {
-        //     _helper.exec_loop_bhl(query,
-        //                           present_key,
-        //                           present_value,
-        //                           output_emb,
-        //                           output_score,
-        //                           max_context_len,
-        //                           past_lens,
-        //                           subsequence_begins,
-        //                           block_indices,
-        //                           block_indices_begins,
-        //                           alibi_slopes,
-        //                           score_aggregation_window);
-        // }
-        exec_loop_mixed(query,
-                        present_key,
-                        present_value,
-                        output_emb,
-                        output_score,
-                        max_context_len,
-                        past_lens,
-                        subsequence_begins,
-                        block_indices,
-                        block_indices_begins,
-                        alibi_slopes,
-                        score_aggregation_window,
-                        sparse_attention_mask);
+        if (past_lens.m_dims[0] >= nthr || _workitems.get_reorder_max_batch_size() > 0) {
+            exec_loop_mixed(query,
+                            present_key,
+                            present_value,
+                            output_emb,
+                            output_score,
+                            max_context_len,
+                            past_lens,
+                            subsequence_begins,
+                            block_indices,
+                            block_indices_begins,
+                            alibi_slopes,
+                            score_aggregation_window,
+                            sparse_attention_mask);
+        } else {
+            _helper.exec_loop_bhl(query,
+                                  present_key,
+                                  present_value,
+                                  output_emb,
+                                  output_score,
+                                  max_context_len,
+                                  past_lens,
+                                  subsequence_begins,
+                                  block_indices,
+                                  block_indices_begins,
+                                  alibi_slopes,
+                                  score_aggregation_window,
+                                  sparse_attention_mask);
+        }
+        // exec_loop_mixed(query,
+        //                 present_key,
+        //                 present_value,
+        //                 output_emb,
+        //                 output_score,
+        //                 max_context_len,
+        //                 past_lens,
+        //                 subsequence_begins,
+        //                 block_indices,
+        //                 block_indices_begins,
+        //                 alibi_slopes,
+        //                 score_aggregation_window,
+        //                 sparse_attention_mask);
     }
 };
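The final hunk also re-enables the batch-size dispatch that was previously commented out, threading sparse_attention_mask through both paths. A minimal sketch of that dispatch shape, with stub functions standing in for the real exec_loop_mixed / exec_loop_bhl (names taken from the diff, bodies and the thread-count constant hypothetical):

// Hedged sketch of the restored dispatch; stubs only, not the PR's code.
#include <cstddef>

constexpr size_t kNthr = 16;  // stand-in for parallel_get_max_threads()

void exec_loop_mixed() {}  // stub: parallelizes across sequences
void exec_loop_bhl() {}    // stub: parallelizes across (batch, head, kv-len)

void dispatch(size_t batch, size_t reorder_max_batch) {
    // Enough sequences (or pending reorder work) to keep threads busy per
    // sequence: take the mixed path; otherwise split work per head/block.
    if (batch >= kNthr || reorder_max_batch > 0) {
        exec_loop_mixed();
    } else {
        exec_loop_bhl();
    }
}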
36033635