@@ -889,17 +889,22 @@ struct MHAHelper {
             }
         }
 
-        // sparse attention mask: set score to -inf for (q_blk, k_blk) where
+        // Instead of writing -inf directly into scores, build a softmax mask (0/-inf) and pass it to the kernel
+        DATA_TYPE* softmax_mask = nullptr;
+        std::vector<DATA_TYPE> softmax_mask_storage;
         if (!sparse_attention_mask.empty()) {
-            float* score_base = _weight.ptr<float>(ithr, h - hq_beg, 0);
-            for (size_t k_blk = 0; k_blk < cur_kv_len_blocks; k_blk++) {
-                if (!sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, k_blk)[0]) {
-                    for (size_t m = 0; m < q_cnt; m++) {
-                        float* score_blk = score_base + m * cur_kv_len_blocks * _block_size + k_blk * _block_size;
-                        std::fill(score_blk, score_blk + _block_size, -std::numeric_limits<float>::infinity());
-                    }
+            const size_t padded_len = rnd_up(cur_kv_len, _block_size);
+            softmax_mask_storage.resize(padded_len);
+            // Initialize to -inf by default; then set positions for allowed blocks to 0
+            const DATA_TYPE neg_inf_val = static_cast<DATA_TYPE>(-std::numeric_limits<float>::infinity());
+            std::fill(softmax_mask_storage.begin(), softmax_mask_storage.end(), neg_inf_val);
+            for (size_t k = 0; k < cur_kv_len; ++k) {
+                size_t k_blk = k / _block_size;
+                if (sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk, k_blk)[0]) {
+                    softmax_mask_storage[k] = static_cast<DATA_TYPE>(0);
                 }
             }
+            softmax_mask = softmax_mask_storage.data();
         }
 
         for (size_t m = q_start; m < q_end; m++) {
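For reference, here is a minimal standalone sketch of the step this hunk performs: expanding a per-block boolean keep mask into a per-token additive softmax mask (0 for visible positions, -inf for pruned ones). The names `expand_block_mask`, `block_mask`, `block_size`, and `kv_len` are illustrative rather than the PR's identifiers, and plain `float`/`std::vector` stand in for `DATA_TYPE`/`PlainTensor`.

```cpp
#include <cstddef>
#include <limits>
#include <vector>

// Expand a per-block keep/skip mask into a per-token additive mask: kept
// blocks contribute 0.0f and skipped blocks contribute -inf, so that the
// subsequent softmax assigns zero probability to every pruned KV position.
std::vector<float> expand_block_mask(const std::vector<bool>& block_mask,
                                     std::size_t block_size,
                                     std::size_t kv_len) {
    const float neg_inf = -std::numeric_limits<float>::infinity();
    std::vector<float> token_mask(kv_len, neg_inf);
    for (std::size_t k = 0; k < kv_len; ++k) {
        if (block_mask[k / block_size]) {
            token_mask[k] = 0.0f;  // this KV position stays visible
        }
    }
    return token_mask;
}
```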
@@ -923,7 +928,7 @@ struct MHAHelper {
                 reinterpret_cast<DATA_TYPE*>(score) + start_idx,
                 revised_d_scale,
                 alibi_lookup,
-                nullptr,
+                reinterpret_cast<void*>(softmax_mask + start_idx),
                 nullptr,
                 false,
                 new_causal,
@@ -944,7 +949,7 @@ struct MHAHelper {
                 reinterpret_cast<DATA_TYPE*>(score),
                 revised_d_scale,
                 alibi_lookup,
-                nullptr,
+                reinterpret_cast<void*>(softmax_mask),
                 nullptr,
                 false,
                 ncausal,
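The two call-site changes above only swap a `nullptr` argument for the new mask pointer; how the kernel consumes it is not shown in this diff. Below is a hedged scalar illustration (not OpenVINO's actual `attn_softmax_kernel`) of the usual convention for an additive 0/-inf mask: add it to the scores before the softmax so masked positions collapse to zero probability. It assumes at least one position is unmasked.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Scalar softmax over one row of attention scores with an additive mask.
// Positions whose mask value is -inf end up with exp(-inf) == 0 weight.
void masked_softmax(std::vector<float>& scores, const std::vector<float>& mask) {
    float max_val = -std::numeric_limits<float>::infinity();
    for (std::size_t i = 0; i < scores.size(); ++i) {
        scores[i] += mask[i];                  // apply the 0 / -inf additive mask
        max_val = std::max(max_val, scores[i]);
    }
    float sum = 0.0f;
    for (float& s : scores) {
        s = std::exp(s - max_val);             // numerically stable exponentiation
        sum += s;
    }
    for (float& s : scores) {
        s /= sum;                              // normalize to probabilities
    }
}
```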
@@ -1856,7 +1861,7 @@ struct MHA {
                         score_info_ptr,
                         batch_in_seq,
                         sparse_attention_mask);
-#endif
+#endif
             }
         });
         if (output_score) {
@@ -2158,95 +2163,80 @@ struct AttentionExecutor : public PagedAttentionExecutor {
         // TODO: enable block_size to be multiple of 32
         OPENVINO_ASSERT(block_size == 32, "CPU: block size must be 32, current: ", block_size);
 
-        // TODO: use the inputs values
-        // PlainTensor sum;
-        // PlainTensor mask;
         size_t xt_stride = 16;
         // The original block_size of the sparse attention mask;
         size_t xt_block_size = 128;
         // auto xt_block_size = 32;
-        float xt_threshold = 0.9f;
-
-        // --- Create and initialize sparse_attention_mask ---
-        sparse_attention_mask.clear();
-        // TODO: maybe use real context_len to save memory usage
-        size_t k_blocks = div_up(max_context_len, xt_block_size);
-        for (size_t b = 0; b < B_seq; ++b) {
-            auto q_len_for_batch = subsequence_begins.ptr<int32_t>()[b + 1] - subsequence_begins.ptr<int32_t>()[b];
-            size_t q_blocks = div_up(static_cast<size_t>(q_len_for_batch), xt_block_size);
-            PlainTensor mask;
-            mask.resize<bool>({H, q_blocks, k_blocks});
-            // Default initialize all to false
-            std::memset(mask.ptr<bool>(), 0, H * q_blocks * k_blocks * sizeof(bool));
-            sparse_attention_mask.push_back(std::move(mask));
-        }
-
-        // TODO: Avoid temporary vector and assignment here. Consider changing get_sparse_blocks
-        // to take `std::vector<PlainTensor>& sparse_attention_mask` and fill it in-place
-        // to reduce PlainTensor copies (memcpy of metadata) and allocations.
-        sparse_attention_mask = get_sparse_blocks(q,
-                                                  k,
-                                                  past_lens,
-                                                  subsequence_begins,
-                                                  block_indices,
-                                                  block_indices_begins,
-                                                  xt_stride,
-                                                  xt_block_size,
-                                                  xt_threshold);
-
-        // --- Broadcast sparse_attention_mask to support different block sizes ---
-        // The input mask may be [h, q_blocks_orig, k_blocks_orig], and needs to be broadcast to [h, q_blocks, k_blocks]
-        auto broadcast_sparse_attention_mask =
-            [](std::vector<PlainTensor>& mask_vec, size_t src_block_size, size_t dst_block_size) {
-                if (src_block_size == dst_block_size)
-                    return;
-                if (src_block_size % dst_block_size != 0) {
-                    OPENVINO_THROW("not supported: sparse_attention_BlockSize=",
-                                   src_block_size,
-                                   " is not an integer multiple of block_size=",
-                                   dst_block_size);
-                }
-                size_t scale = src_block_size / dst_block_size;
-                for (auto& mask : mask_vec) {
-                    auto shape = mask.shape();
-                    size_t H = shape[0];
-                    size_t q_blocks_orig = shape[1];
-                    size_t k_blocks_orig = shape[2];
-                    size_t q_blocks = q_blocks_orig * scale;
-                    size_t k_blocks = k_blocks_orig * scale;
-                    PlainTensor new_mask;
-                    new_mask.resize<bool>({H, q_blocks, k_blocks});
-                    std::memset(new_mask.ptr<bool>(), 0, H * q_blocks * k_blocks * sizeof(bool));
-                    for (size_t h = 0; h < H; ++h) {
-                        for (size_t q_blk = 0; q_blk < q_blocks_orig; ++q_blk) {
-                            for (size_t k_blk = 0; k_blk < k_blocks_orig; ++k_blk) {
-                                bool val = mask.ptr<bool>(h, q_blk, k_blk)[0];
-                                for (size_t dq = 0; dq < scale; ++dq) {
-                                    for (size_t dk = 0; dk < scale; ++dk) {
-                                        new_mask.ptr<bool>(h, q_blk * scale + dq, k_blk * scale + dk)[0] = val;
+        float xt_threshold = 0.6f;
+        // float xt_threshold = 1.0f;
+
+        // To support second-token sparse attention, the sparse mask would need to be generated after concat_pastkv
+        if (q.size(0) > 1) {
+            sparse_attention_mask = get_sparse_blocks(q,
+                                                      k,
+                                                      past_lens,
+                                                      subsequence_begins,
+                                                      block_indices,
+                                                      block_indices_begins,
+                                                      xt_stride,
+                                                      xt_block_size,
+                                                      xt_threshold);
+
+            // --- Broadcast sparse_attention_mask to support different block sizes ---
+            // The input mask may be [h, q_blocks_orig, k_blocks_orig], and needs to be broadcast to [h, q_blocks,
+            // k_blocks]
+            auto broadcast_sparse_attention_mask =
+                [](std::vector<PlainTensor>& mask_vec, size_t src_block_size, size_t dst_block_size) {
+                    if (src_block_size == dst_block_size)
+                        return;
+                    if (src_block_size % dst_block_size != 0) {
+                        OPENVINO_THROW("not supported: sparse_attention_BlockSize=",
+                                       src_block_size,
+                                       " is not an integer multiple of block_size=",
+                                       dst_block_size);
+                    }
+                    size_t scale = src_block_size / dst_block_size;
+                    for (auto& mask : mask_vec) {
+                        auto shape = mask.shape();
+                        size_t H = shape[0];
+                        size_t q_blocks_orig = shape[1];
+                        size_t k_blocks_orig = shape[2];
+                        size_t q_blocks = q_blocks_orig * scale;
+                        size_t k_blocks = k_blocks_orig * scale;
+                        PlainTensor new_mask;
+                        new_mask.resize<bool>({H, q_blocks, k_blocks});
+                        std::memset(new_mask.ptr<bool>(), 0, H * q_blocks * k_blocks * sizeof(bool));
+                        for (size_t h = 0; h < H; ++h) {
+                            for (size_t q_blk = 0; q_blk < q_blocks_orig; ++q_blk) {
+                                for (size_t k_blk = 0; k_blk < k_blocks_orig; ++k_blk) {
+                                    bool val = mask.ptr<bool>(h, q_blk, k_blk)[0];
+                                    for (size_t dq = 0; dq < scale; ++dq) {
+                                        for (size_t dk = 0; dk < scale; ++dk) {
+                                            new_mask.ptr<bool>(h, q_blk * scale + dq, k_blk * scale + dk)[0] = val;
+                                        }
                                     }
                                 }
                             }
                         }
+                        mask = std::move(new_mask);
                     }
-                    mask = std::move(new_mask);
+                };
+            // The original block_size of the sparse attention mask; can be specified later via the Page Attention Node
+            // parameter const size_t sparse_attention_BlockSize = 128;
+            // Only support block_size <= sparse_attention_BlockSize and sparse_attention_BlockSize must be an integer
+            // multiple
+            if (block_size != xt_block_size) {
+                if (block_size > xt_block_size) {
+                    OPENVINO_THROW("not supported: block_size > xt_block_size");
                 }
-            };
-        // The original block_size of the sparse attention mask; can be specified later via the Page Attention Node
-        // parameter const size_t sparse_attention_BlockSize = 128;
-        // Only support block_size <= sparse_attention_BlockSize and sparse_attention_BlockSize must be an integer
-        // multiple
-        if (block_size != xt_block_size) {
-            if (block_size > xt_block_size) {
-                OPENVINO_THROW("not supported: block_size > xt_block_size");
-            }
-            if (xt_block_size % block_size != 0) {
-                OPENVINO_THROW("not supported: xt_block_size ",
-                               xt_block_size,
-                               " is not an integer multiple of block_size ",
-                               block_size);
+                if (xt_block_size % block_size != 0) {
+                    OPENVINO_THROW("not supported: xt_block_size ",
+                                   xt_block_size,
+                                   " is not an integer multiple of block_size ",
+                                   block_size);
+                }
+                broadcast_sparse_attention_mask(sparse_attention_mask, xt_block_size, block_size);
             }
-            broadcast_sparse_attention_mask(sparse_attention_mask, xt_block_size, block_size);
         }
 
         _helper.init(H,
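As a side note, the `broadcast_sparse_attention_mask` lambda above is a nearest-neighbour upsampling of the block mask: every coarse (128-token) cell is replicated into a scale x scale tile of fine (32-token) cells. Below is a standalone sketch of that replication with a flat row-major `std::vector<bool>` in place of `PlainTensor`; the function and parameter names are illustrative, not part of the PR.

```cpp
#include <cstddef>
#include <vector>

// Replicate each cell of a [q_blocks, k_blocks] mask into a (scale x scale)
// tile, producing a [q_blocks * scale, k_blocks * scale] mask stored row-major.
std::vector<bool> broadcast_block_mask(const std::vector<bool>& src,
                                       std::size_t q_blocks,
                                       std::size_t k_blocks,
                                       std::size_t scale) {
    const std::size_t dst_k = k_blocks * scale;
    std::vector<bool> dst(q_blocks * scale * dst_k, false);
    for (std::size_t q = 0; q < q_blocks; ++q) {
        for (std::size_t k = 0; k < k_blocks; ++k) {
            const bool val = src[q * k_blocks + k];
            for (std::size_t dq = 0; dq < scale; ++dq) {
                for (std::size_t dk = 0; dk < scale; ++dk) {
                    dst[(q * scale + dq) * dst_k + (k * scale + dk)] = val;
                }
            }
        }
    }
    return dst;
}
```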