@@ -584,6 +584,8 @@ struct MHAHelper {
     std::vector<ScoreAggregationInfo> _score_infos;
 
     PlainTensor _block_rotation_coefficient_scratch;
+    // Block size used when generating sparse_attention_mask (0 means unspecified/equal to _block_size)
+    size_t _sparse_mask_block_size = 0;
 
     MHAHelper() {
         _weight.resize<float>({size_t{1}, size_t{1}, size_t{1}, size_t{1}});
@@ -853,12 +855,25 @@ struct MHAHelper {
         // 1 1 0 0 ...
         // 1 1 1 0 ...
         // just computing the positions of 1 should be enough
+        // map runtime (block_size) indices to mask (xt_block_size) indices
+        auto map_to_mask_idx = [&](size_t q_blk_rt, size_t k_blk_rt) {
+            if (_sparse_mask_block_size == 0 || _sparse_mask_block_size == _block_size) {
+                return std::pair<size_t, size_t>{q_blk_rt, k_blk_rt};
+            }
+            // Only support mask block >= runtime block and divisible (checked in init)
+            size_t scale = _sparse_mask_block_size / _block_size;  // >= 1
+            size_t q_mask = q_blk_rt / scale;
+            size_t k_mask = k_blk_rt / scale;
+            return std::pair<size_t, size_t>{q_mask, k_mask};
+        };
         for (size_t k_blk = 0; k_blk < cur_kv_len_blocks; k_blk++) {
             // sparse attention mask filtering
-            if (!sparse_attention_mask.empty() &&
-                !sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, k_blk)[0]) {
-                // Skip GEMM for this block if mask is false
-                continue;
+            if (!sparse_attention_mask.empty()) {
+                auto [q_m, k_m] = map_to_mask_idx(q_blk, k_blk);
+                if (!sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_m, k_m)[0]) {
+                    // Skip GEMM for this block if mask is false
+                    continue;
+                }
             }
             if (_params.is_sage_attn) {
 #if defined(OPENVINO_ARCH_X86_64)
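
As a quick illustration of the mapping (a standalone sketch, not part of the patch; the block sizes are example values): with runtime _block_size = 32 and a mask generated at xt_block_size = 128, scale = 4, so runtime blocks 0..3 all consult mask block 0.

    #include <cstddef>
    #include <cstdio>
    #include <utility>

    // Free-function mirror of the map_to_mask_idx lambda above (hypothetical standalone version).
    std::pair<std::size_t, std::size_t> map_to_mask_idx(std::size_t q_blk, std::size_t k_blk,
                                                        std::size_t block_size, std::size_t mask_block_size) {
        if (mask_block_size == 0 || mask_block_size == block_size)
            return {q_blk, k_blk};
        std::size_t scale = mask_block_size / block_size;  // divisibility is checked by the caller
        return {q_blk / scale, k_blk / scale};
    }

    int main() {
        for (std::size_t q = 0; q < 8; ++q) {
            auto [q_m, k_m] = map_to_mask_idx(q, q, 32, 128);
            std::printf("runtime block (%zu, %zu) -> mask block (%zu, %zu)\n", q, q, q_m, k_m);
        }
    }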
@@ -900,7 +915,8 @@ struct MHAHelper {
             std::fill(softmax_mask_storage.begin(), softmax_mask_storage.end(), neg_inf_val);
             for (size_t k = 0; k < cur_kv_len; ++k) {
                 size_t k_blk = k / _block_size;
-                if (sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, k_blk)[0]) {
+                auto [q_m, k_m] = map_to_mask_idx(q_blk, k_blk);
+                if (sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_m, k_m)[0]) {
                     softmax_mask_storage[k] = static_cast<DATA_TYPE>(0);
                 }
             }
@@ -980,9 +996,11 @@ struct MHAHelper {
         // for each weight block, loop through all value block
         for (size_t v_blk = 0; v_blk < cur_kv_len_blocks; v_blk++) {
             // sparse attention mask filtering for value blocks
-            if (!sparse_attention_mask.empty() &&
-                !sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, v_blk)[0]) {
-                continue;
+            if (!sparse_attention_mask.empty()) {
+                auto [q_m, v_m] = map_to_mask_idx(q_blk, v_blk);
+                if (!sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_m, v_m)[0]) {
+                    continue;
+                }
             }
             DATA_TYPE* v_ptr = nullptr;
             if (q_is_xf16 || !q_cache_is_same) {
@@ -2182,47 +2200,6 @@ struct AttentionExecutor : public PagedAttentionExecutor {
                                         xt_block_size,
                                         xt_threshold);
 
-        // --- Broadcast sparse_attention_mask to support different block sizes ---
-        // The input mask may be [h, q_blocks_orig, k_blocks_orig], and needs to be broadcast to [h, q_blocks,
-        // k_blocks]
-        auto broadcast_sparse_attention_mask =
-            [](std::vector<PlainTensor>& mask_vec, size_t src_block_size, size_t dst_block_size) {
-                if (src_block_size == dst_block_size)
-                    return;
-                if (src_block_size % dst_block_size != 0) {
-                    OPENVINO_THROW("not supported: sparse_attention_BlockSize=",
-                                   src_block_size,
-                                   " is not an integer multiple of block_size=",
-                                   dst_block_size);
-                }
-                size_t scale = src_block_size / dst_block_size;
-                for (auto& mask : mask_vec) {
-                    auto shape = mask.shape();
-                    size_t H = shape[0];
-                    size_t q_blocks_orig = shape[1];
-                    size_t k_blocks_orig = shape[2];
-                    size_t q_blocks = q_blocks_orig * scale;
-                    size_t k_blocks = k_blocks_orig * scale;
-                    PlainTensor new_mask;
-                    new_mask.resize<bool>({H, q_blocks, k_blocks});
-                    std::memset(new_mask.ptr<bool>(), 0, H * q_blocks * k_blocks * sizeof(bool));
-                    for (size_t h = 0; h < H; ++h) {
-                        for (size_t q_blk = 0; q_blk < q_blocks_orig; ++q_blk) {
-                            for (size_t k_blk = 0; k_blk < k_blocks_orig; ++k_blk) {
-                                bool val = mask.ptr<bool>(h, q_blk, k_blk)[0];
-                                for (size_t dq = 0; dq < scale; ++dq) {
-                                    for (size_t dk = 0; dk < scale; ++dk) {
-                                        new_mask.ptr<bool>(h, q_blk * scale + dq, k_blk * scale + dk)[0] = val;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                    mask = std::move(new_mask);
-                }
-            };
-        // The original block_size of the sparse attention mask; can be specified later via the Page Attention Node
-        // parameter: const size_t sparse_attention_BlockSize = 128;
         // Only block_size <= xt_block_size is supported, and xt_block_size must be
         // an integer multiple of block_size
         if (block_size != xt_block_size) {
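
Dropping the broadcast also avoids materializing a mask scale^2 times larger than the one produced at xt_block_size granularity, plus the copy loops. A rough sizing sketch (standalone, assumed shapes, not part of the patch):

    #include <cstddef>
    #include <cstdio>

    int main() {
        std::size_t H = 32, q_blocks = 64, k_blocks = 64;  // assumed coarse-mask shape [H, q_blocks, k_blocks]
        std::size_t scale = 128 / 32;                      // xt_block_size / block_size (example values)
        std::size_t coarse = H * q_blocks * k_blocks;
        std::size_t broadcast = coarse * scale * scale;    // what the removed lambda allocated
        std::printf("coarse: %zu bools, broadcast: %zu bools (%zux larger)\n",
                    coarse, broadcast, scale * scale);
    }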
@@ -2235,8 +2212,9 @@ struct AttentionExecutor : public PagedAttentionExecutor {
                               " is not an integer multiple of block_size ",
                               block_size);
            }
-            broadcast_sparse_attention_mask(sparse_attention_mask, xt_block_size, block_size);
         }
+        // keep original mask granularity; remember its block size for on-the-fly mapping
+        _helper._sparse_mask_block_size = xt_block_size;
     }
 
     _helper.init(H,
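The rewrite is behavior-preserving: reading the coarse mask through the index mapping selects exactly the blocks the old broadcast-then-index path selected. A standalone equivalence check (assumed sizes, not part of the patch):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::size_t scale = 4;  // e.g. xt_block_size 128 / block_size 32
        const std::size_t qc = 2, kc = 2;
        std::vector<bool> coarse = {true, false, false, true};  // [qc x kc], row-major

        // Expand exactly as the removed broadcast lambda did (dq/dk loops).
        const std::size_t qf = qc * scale, kf = kc * scale;
        std::vector<bool> fine(qf * kf, false);
        for (std::size_t q = 0; q < qc; ++q)
            for (std::size_t k = 0; k < kc; ++k)
                for (std::size_t dq = 0; dq < scale; ++dq)
                    for (std::size_t dk = 0; dk < scale; ++dk)
                        fine[(q * scale + dq) * kf + (k * scale + dk)] = coarse[q * kc + k];

        // The on-the-fly mapping must read the same value for every fine-grained block.
        std::size_t mismatches = 0;
        for (std::size_t q = 0; q < qf; ++q)
            for (std::size_t k = 0; k < kf; ++k)
                if (fine[q * kf + k] != coarse[(q / scale) * kc + (k / scale)])
                    ++mismatches;
        std::printf("mismatches: %zu\n", mismatches);  // expect 0
    }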