@@ -2470,7 +2470,9 @@ struct MHAHelper {
                              const PlainTensor& alibi_slopes,
                              float* score_output,
                              size_t q_start_idx_score,
-                             const ScoreAggregationInfo* score_info_ptr) {
+                             const ScoreAggregationInfo* score_info_ptr,
+                             size_t batch_in_seq = 0,
+                             const std::vector<PlainTensor>& sparse_attention_mask = {}) {
         auto q_start = q_blk * _block_size;
         auto q_end = std::min(q_start + _block_size, q_len);
         auto q_cnt = q_end - q_start;
@@ -2487,6 +2489,12 @@ struct MHAHelper {
         // 1 1 1 0 ...
         // just computing the positions of 1 should be enough
         for (size_t k_blk = 0; k_blk < cur_kv_len_blocks; k_blk++) {
+            // sparse attention mask filtering
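+            // (the mask holds one bool per (head, query-block, key-block) tile, laid out
+            // as [H, q_blocks, k_blocks]; see its construction in AttentionExecutor::init)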
+            if (!sparse_attention_mask.empty() &&
+                !sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, k_blk)[0]) {
+                // the mask is false: skip the QK GEMM for this block
+                continue;
+            }
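+            // (scores for blocks skipped here are overwritten with -inf below, before softmax)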
             auto* k_ptr = qk_scratch_b.ptr<DATA_TYPE>(k_blk, hk);
             _qk_gemm[q_cnt - 1]->executeGemm(q_cnt < _block_size,
                                              q_ptr,
@@ -2496,6 +2504,19 @@ struct MHAHelper {
                                              _qk_scratch_a ? _qk_scratch_a.ptr<DATA_TYPE>(ithr, 0) : nullptr);
         }

+        // sparse attention mask: when the mask for (q_blk, k_blk) is false, fill the score
+        // buffer for that block (across the whole q_blk) with -inf
+        if (!sparse_attention_mask.empty()) {
+            float* score_base = _weight.ptr<float>(ithr, h - hq_beg, 0);
+            for (size_t k_blk = 0; k_blk < cur_kv_len_blocks; k_blk++) {
+                if (!sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, k_blk)[0]) {
+                    for (size_t m = 0; m < q_cnt; m++) {
+                        float* score_blk = score_base + m * cur_kv_len_blocks * _block_size + k_blk * _block_size;
+                        std::fill(score_blk, score_blk + _block_size, -std::numeric_limits<float>::infinity());
+                    }
+                }
+            }
+        }
+
         for (size_t m = q_start; m < q_end; m++) {
             // apply attention mask & sofmax
             auto ncausal = (cur_kv_len - q_cnt + (m - q_start) + 1);
@@ -2563,6 +2584,11 @@ struct MHAHelper {

         // for each weight block, loop through all value block
         for (size_t v_blk = 0; v_blk < cur_kv_len_blocks; v_blk++) {
+            // sparse attention mask filtering for value blocks
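+            // (skipping is exact here: masked blocks were filled with -inf before softmax,
+            // so their attention weights are zero and contribute nothing to the output)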
+            if (!sparse_attention_mask.empty() &&
+                !sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, v_blk)[0]) {
+                continue;
+            }
             DATA_TYPE* v_ptr = nullptr;
             if (q_is_xf16 || !q_cache_is_same) {
                 v_ptr = wv_scratch_b.ptr<DATA_TYPE>(v_blk, hk);
@@ -2762,14 +2788,21 @@ struct MHAHelper {
                             size_t q_len,
                             size_t cur_kv_len,
                             const PlainTensor& alibi_slopes,
-                            float* score_output) {
+                            float* score_output,
+                            const std::vector<PlainTensor>& sparse_attention_mask) {
 #if defined(OPENVINO_ARCH_X86_64)
         if (one_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) {
             _gemv->tile_config();
             for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) {
                 auto block_number = block_table[i];
                 for (size_t pq = 0; pq < q_len; pq++) {
+                    size_t q_blk = pq / _block_size;
                     for (size_t h = hq_beg; h < hq_end; h++) {
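+                        // a skipped key block leaves stale values in _weight; the -inf pass
+                        // applied before softmax (further below) overwrites those positions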
+                        size_t k_blk = pk / _block_size;
+                        // only process blocks whose mask is true
+                        if (!sparse_attention_mask[0].ptr<bool>(h, q_blk, k_blk)[0]) {
+                            continue;
+                        }
                         (*_gemv)(
                             query.ptr<DATA_TYPE>(h, pq),
                             present_key.ptr<typename ov::element_type_traits<KEY_PREC>::value_type>(block_number, hk),
@@ -2783,7 +2816,13 @@ struct MHAHelper {
             for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) {
                 auto block_number = block_table[i];
                 for (size_t pq = 0; pq < q_len; pq++) {
+                    size_t q_blk = pq / _block_size;
                     for (size_t h = hq_beg; h < hq_end; h++) {
+                        size_t k_blk = pk / _block_size;
+                        // only process blocks whose mask is true
+                        if (!sparse_attention_mask[0].ptr<bool>(h, q_blk, k_blk)[0]) {
+                            continue;
+                        }
                         if constexpr (KEY_PREC == ov::element::u8 || KEY_PREC == ov::element::u4) {
                             dot_product_block_quantized<DATA_TYPE, KEY_PREC>(
                                 query.ptr<DATA_TYPE>(h, pq),
@@ -2811,6 +2850,7 @@ struct MHAHelper {
 #endif

         for (size_t pq = 0; pq < q_len; pq++) {
+            size_t q_blk = pq / _block_size;
             for (size_t h = hq_beg; h < hq_end; h++) {
                 // apply attention mask & sofmax
                 float* alibi_lookup = nullptr;
@@ -2819,6 +2859,14 @@ struct MHAHelper {
                     alibi_slope = alibi_slopes.ptr<float>()[h];
                     alibi_lookup = _alibi_lookup.ptr<float>() + _alibi_lookup.m_dims[0] - cur_kv_len;
                 }
+                // mask: before softmax, set positions whose mask is false to -inf
+                // TODO: parallel process outside q loop
+                for (size_t k = 0; k < cur_kv_len; ++k) {
+                    size_t k_blk = k / _block_size;
+                    if (!sparse_attention_mask[0].ptr<bool>(h, q_blk, k_blk)[0]) {
+                        _weight.ptr<float>(ithr, h - hq_beg, pq)[k] = -std::numeric_limits<float>::infinity();
+                    }
+                }
                 attn_softmax_kernel<float>(_weight.ptr<float>(ithr, h - hq_beg, pq),
                                            _weight.ptr<float>(ithr, h - hq_beg, pq),
                                            _d_scale,
@@ -2845,7 +2893,13 @@ struct MHAHelper {
         for (size_t pv = 0, i = 0; pv < cur_kv_len; pv += _block_size, i++) {
             auto block_number = block_table[i];
             for (size_t pq = 0; pq < q_len; pq++) {
+                size_t q_blk = pq / _block_size;
                 for (size_t h = hq_beg; h < hq_end; h++) {
+                    size_t k_blk = pv / _block_size;
+                    // only process blocks whose mask is true
+                    if (!sparse_attention_mask[0].ptr<bool>(h, q_blk, k_blk)[0]) {
+                        continue;
+                    }
                     if constexpr (one_of(VALUE_PREC, ov::element::u8, ov::element::u4)) {
                         attn_acc_value_block_quantized<uint8_t, VALUE_PREC>(
                             _output.ptr<float>(ithr, pq, h),
@@ -3211,7 +3265,8 @@ struct MHA {
                          const PlainTensor& block_indices,
                          const PlainTensor& block_indices_begins,
                          const PlainTensor& alibi_slopes,
-                         const PlainTensor& score_aggregation_window) {
+                         const PlainTensor& score_aggregation_window,
+                         const std::vector<PlainTensor>& sparse_attention_mask) {
         auto Hk = v_cache.m_dims[1];

         constexpr bool q_is_xf16 = one_of(precision_of<DATA_TYPE>::value, ov::element::bf16, ov::element::f16);
@@ -3361,7 +3416,8 @@ struct MHA {
                                              1UL,
                                              cur_kv_len,
                                              alibi_slopes,
-                                             score_output);
+                                             score_output,
+                                             sparse_attention_mask);
             } else {
                 const auto batch_in_reorder = item.batch_in_reorder;
                 const auto q_blk = item.q_block_id;
@@ -3431,7 +3487,9 @@ struct MHA {
                                                 alibi_slopes,
                                                 score_output,
                                                 q_start_idx_score,
-                                                score_info_ptr);
+                                                score_info_ptr,
+                                                batch_in_seq,
+                                                sparse_attention_mask);
                 }
 #else
                 _helper.exec_kernel_multiple(
@@ -3452,8 +3510,10 @@ struct MHA {
                     1UL,
                     cur_kv_len,
                     alibi_slopes,
-                    score_info_ptr);
-#endif
+                    score_info_ptr,
+                    batch_in_seq,
+                    sparse_attention_mask);
+#endif
             }
         });
         if (output_score) {
@@ -3489,41 +3549,55 @@ struct MHA {
                     const PlainTensor& block_indices,
                     const PlainTensor& block_indices_begins,
                     const PlainTensor& alibi_slopes,
-                    const PlainTensor& score_aggregation_window) {
+                    const PlainTensor& score_aggregation_window,
+                    const std::vector<PlainTensor>& sparse_attention_mask) {
         _workitems.reset(query, past_lens, subsequence_begins, _helper._block_size);
         if (output_score) {
             _helper.init_score_buffers(past_lens, subsequence_begins, score_aggregation_window);
         }

         auto nthr = static_cast<size_t>(parallel_get_max_threads());

-        if (past_lens.m_dims[0] >= nthr || _workitems.get_reorder_max_batch_size() > 0) {
-            exec_loop_mixed(query,
-                            present_key,
-                            present_value,
-                            output_emb,
-                            output_score,
-                            max_context_len,
-                            past_lens,
-                            subsequence_begins,
-                            block_indices,
-                            block_indices_begins,
-                            alibi_slopes,
-                            score_aggregation_window);
-        } else {
-            _helper.exec_loop_bhl(query,
-                                  present_key,
-                                  present_value,
-                                  output_emb,
-                                  output_score,
-                                  max_context_len,
-                                  past_lens,
-                                  subsequence_begins,
-                                  block_indices,
-                                  block_indices_begins,
-                                  alibi_slopes,
-                                  score_aggregation_window);
-        }
+        // if (past_lens.m_dims[0] >= nthr || _workitems.get_reorder_max_batch_size() > 0) {
+        //     exec_loop_mixed(query,
+        //                     present_key,
+        //                     present_value,
+        //                     output_emb,
+        //                     output_score,
+        //                     max_context_len,
+        //                     past_lens,
+        //                     subsequence_begins,
+        //                     block_indices,
+        //                     block_indices_begins,
+        //                     alibi_slopes,
+        //                     score_aggregation_window);
+        // } else {
+        //     _helper.exec_loop_bhl(query,
+        //                           present_key,
+        //                           present_value,
+        //                           output_emb,
+        //                           output_score,
+        //                           max_context_len,
+        //                           past_lens,
+        //                           subsequence_begins,
+        //                           block_indices,
+        //                           block_indices_begins,
+        //                           alibi_slopes,
+        //                           score_aggregation_window);
+        // }
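+        // NOTE: the thread-count-based dispatch above is disabled for now: exec_loop_bhl
+        // does not take the sparse attention mask yet, so all work is routed through
+        // exec_loop_mixed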
+        exec_loop_mixed(query,
+                        present_key,
+                        present_value,
+                        output_emb,
+                        output_score,
+                        max_context_len,
+                        past_lens,
+                        subsequence_begins,
+                        block_indices,
+                        block_indices_begins,
+                        alibi_slopes,
+                        score_aggregation_window,
+                        sparse_attention_mask);
     }
 };

@@ -3565,7 +3639,8 @@ struct AttentionExecutor : public PagedAttentionExecutor {
               PlainTensor& rotation_deltas,
               PlainTensor& rotation_trig_lut,
               PlainTensor& output_emb,
-              PlainTensor& output_score) {
+              PlainTensor& output_score,
+              std::vector<PlainTensor>& sparse_attention_mask) {
         q.reset(inputs[ID_Q]);  // [B_token, H * S]
         k.reset(inputs[ID_K]);
         v.reset(inputs[ID_V]);
@@ -3741,6 +3816,96 @@ struct AttentionExecutor : public PagedAttentionExecutor {
         // TODO: enable block_size to be multiple of 32
         OPENVINO_ASSERT(block_size == 32, "CPU: block size must be 32, current: ", block_size);

+        // --- create and initialize sparse_attention_mask ---
+        sparse_attention_mask.clear();
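+        // k_blocks is sized from max_context_len, so every sequence's mask shares the same
+        // key-block extent; entries beyond a shorter sequence's kv length are simply unused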
+        size_t k_blocks = div_up(max_context_len, block_size);
+        for (size_t b = 0; b < B_seq; ++b) {
+            auto q_len_for_batch = subsequence_begins.ptr<int32_t>()[b + 1] - subsequence_begins.ptr<int32_t>()[b];
+            size_t q_blocks = div_up(static_cast<size_t>(q_len_for_batch), block_size);
+            PlainTensor mask;
+            mask.resize<bool>({H, q_blocks, k_blocks});
+            // initialize everything to false by default
+            std::memset(mask.ptr<bool>(), 0, H * q_blocks * k_blocks * sizeof(bool));
+            // optional: activate all blocks (example)
+            for (size_t h = 0; h < H; ++h) {
+                for (size_t q_blk = 0; q_blk < q_blocks; ++q_blk) {
+                    for (size_t k_blk = 0; k_blk < k_blocks; ++k_blk) {
+                        // symmetric about (q_blocks/2, k_blocks/2): top-left and bottom-right true, the rest false
+                        // bool left_top = (q_blk < q_blocks / 2) && (k_blk < k_blocks / 2);
+                        // bool right_bottom = (q_blk >= (q_blocks + 1) / 2) && (k_blk >= (k_blocks + 1) / 2);
+                        // mask.ptr<bool>(h, q_blk, k_blk)[0] = left_top || right_bottom;

+                        // set every mask entry to true
+                        mask.ptr<bool>(h, q_blk, k_blk)[0] = true;

+                        // // set the single middle block to false
+                        // if (q_blk == q_blocks / 2 && k_blk == k_blocks / 2) {
+                        //     mask.ptr<bool>(h, q_blk, k_blk)[0] = false;
+                        // } else {
+                        //     mask.ptr<bool>(h, q_blk, k_blk)[0] = true;
+                        // }
+                    }
+                }
+            }
+            sparse_attention_mask.push_back(std::move(mask));
+        }
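+        // TODO: the mask above is synthesized in place (all true) for bring-up and testing;
+        // it is expected to eventually arrive through the PagedAttention node instead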
+
+        // --- broadcast sparse_attention_mask to support differing block sizes ---
+        // the input mask may be [H, q_blocks_orig, k_blocks_orig] and needs broadcasting to [H, q_blocks, k_blocks]
+        auto broadcast_sparse_attention_mask =
+            [](std::vector<PlainTensor>& mask_vec, size_t src_block_size, size_t dst_block_size) {
+                if (src_block_size == dst_block_size)
+                    return;
+                if (src_block_size % dst_block_size != 0) {
+                    OPENVINO_THROW("not supported: sparse_attention_BlockSize=",
+                                   src_block_size,
+                                   " with block_size=",
+                                   dst_block_size);
+                }
+                size_t scale = src_block_size / dst_block_size;
+                for (auto& mask : mask_vec) {
+                    auto shape = mask.shape();
+                    size_t H = shape[0];
+                    size_t q_blocks_orig = shape[1];
+                    size_t k_blocks_orig = shape[2];
+                    size_t q_blocks = q_blocks_orig * scale;
+                    size_t k_blocks = k_blocks_orig * scale;
+                    PlainTensor new_mask;
+                    new_mask.resize<bool>({H, q_blocks, k_blocks});
+                    std::memset(new_mask.ptr<bool>(), 0, H * q_blocks * k_blocks * sizeof(bool));
+                    for (size_t h = 0; h < H; ++h) {
+                        for (size_t q_blk = 0; q_blk < q_blocks_orig; ++q_blk) {
+                            for (size_t k_blk = 0; k_blk < k_blocks_orig; ++k_blk) {
+                                bool val = mask.ptr<bool>(h, q_blk, k_blk)[0];
+                                for (size_t dq = 0; dq < scale; ++dq) {
+                                    for (size_t dk = 0; dk < scale; ++dk) {
+                                        new_mask.ptr<bool>(h, q_blk * scale + dq, k_blk * scale + dk)[0] = val;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                    mask = std::move(new_mask);
+                }
+            };
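+        // e.g. src_block_size = 128, dst_block_size = 32 gives scale = 4: each source mask
+        // bit is replicated into a 4x4 tile of destination blocks, preserving its coverage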
+        // block size of the original sparse attention mask; to be supplied later as a
+        // PagedAttention node parameter
+        // const size_t sparse_attention_BlockSize = 128;
+        const size_t sparse_attention_BlockSize = 32;
+        // only block_size <= sparse_attention_BlockSize is supported, and
+        // sparse_attention_BlockSize must be an integer multiple of block_size
+        if (block_size != sparse_attention_BlockSize) {
+            if (block_size > sparse_attention_BlockSize) {
+                OPENVINO_THROW("not supported: block_size > sparse_attention_BlockSize");
+            }
+            if (sparse_attention_BlockSize % block_size != 0) {
+                OPENVINO_THROW("not supported: sparse_attention_BlockSize ",
+                               sparse_attention_BlockSize,
+                               " is not an integer multiple of block_size ",
+                               block_size);
+            }
+            broadcast_sparse_attention_mask(sparse_attention_mask, sparse_attention_BlockSize, block_size);
+        }
+
         _helper.init(H,
                      S,
                      SV,
@@ -3823,6 +3988,10 @@ struct AttentionExecutor : public PagedAttentionExecutor {
         PlainTensor output_emb;
         PlainTensor output_score;

+        // one PlainTensor per sequence in the batch, laid out as
+        // [H, q_blocks, k_blocks] with bool elements
+        std::vector<PlainTensor> sparse_attention_mask;
+
         init(inputs,
              outputs,
              q,
@@ -3843,7 +4012,8 @@ struct AttentionExecutor : public PagedAttentionExecutor {
              rotation_deltas,
              rotation_trig_lut,
              output_emb,
-             output_score);
+             output_score,
+             sparse_attention_mask);

         if (rotated_block_indices) {
             // Rotate kv cache currently doesn't support quantized cache.
@@ -3869,7 +4039,8 @@ struct AttentionExecutor : public PagedAttentionExecutor {
                   block_indices,
                   block_indices_begins,
                   alibi_slopes,
-                  score_aggregation_window);
+                  score_aggregation_window,
+                  sparse_attention_mask);
     }
 };
 #endif