@@ -1743,7 +1743,7 @@ struct MHA {
                     score_output = _helper._score_output.template ptr<float>() + score_offset * _helper.H;
                 }
             }
-
+            // TODO: support second token sparse attention execution
            _helper.exec_kernel_one_bh(
                q.slice(0, batch_in_token, batch_in_token),
                k_cache,
@@ -1758,8 +1758,7 @@ struct MHA {
                cur_kv_len,
                alibi_slopes,
                score_output,
-               batch_in_seq,
-               sparse_attention_mask);
+               batch_in_seq);
         } else {
             const auto batch_in_reorder = item.batch_in_reorder;
             const auto q_blk = item.q_block_id;
@@ -1918,6 +1917,7 @@ struct MHA {
                                   score_aggregation_window,
                                   sparse_attention_mask);
         } else {
+            // TODO: support second token sparse attention execution
             _helper.exec_loop_bhl(query,
                                   present_key,
                                   present_value,
@@ -1929,8 +1929,7 @@ struct MHA {
                                   block_indices,
                                   block_indices_begins,
                                   alibi_slopes,
-                                  score_aggregation_window,
-                                  sparse_attention_mask);
+                                  score_aggregation_window);
         }
     }
 };
@@ -2159,40 +2158,42 @@ struct AttentionExecutor : public PagedAttentionExecutor {
         // TODO: enable block_size to be multiple of 32
         OPENVINO_ASSERT(block_size == 32, "CPU: block size must be 32, current: ", block_size);
 
+        // TODO: use the input values
+        // PlainTensor sum;
+        // PlainTensor mask;
+        size_t xt_stride = 16;
+        // The original block_size of the sparse attention mask
+        size_t xt_block_size = 128;
+        // auto xt_block_size = 32;
+        float xt_threshold = 0.9f;
+
         // --- Create and initialize sparse_attention_mask ---
         sparse_attention_mask.clear();
-        size_t k_blocks = div_up(max_context_len, block_size);
+        // TODO: maybe use the real context_len to save memory usage
+        size_t k_blocks = div_up(max_context_len, xt_block_size);
         for (size_t b = 0; b < B_seq; ++b) {
             auto q_len_for_batch = subsequence_begins.ptr<int32_t>()[b + 1] - subsequence_begins.ptr<int32_t>()[b];
-            size_t q_blocks = div_up(static_cast<size_t>(q_len_for_batch), block_size);
+            size_t q_blocks = div_up(static_cast<size_t>(q_len_for_batch), xt_block_size);
             PlainTensor mask;
             mask.resize<bool>({H, q_blocks, k_blocks});
             // Default-initialize all entries to false
             std::memset(mask.ptr<bool>(), 0, H * q_blocks * k_blocks * sizeof(bool));
-            // Optional: activate all (example)
-            for (size_t h = 0; h < H; ++h) {
-                for (size_t q_blk = 0; q_blk < q_blocks; ++q_blk) {
-                    for (size_t k_blk = 0; k_blk < k_blocks; ++k_blk) {
-                        // Around the symmetric point (q_blocks/2, k_blocks/2), set the upper-left and
-                        // lower-right blocks to true and the others to false:
-                        // bool left_top = (q_blk < q_blocks / 2) && (k_blk < k_blocks / 2);
-                        // bool right_bottom = (q_blk >= (q_blocks + 1) / 2) && (k_blk >= (k_blocks + 1) / 2);
-                        // mask.ptr<bool>(h, q_blk, k_blk)[0] = left_top || right_bottom;
-
-                        // All mask entries are set to true
-                        mask.ptr<bool>(h, q_blk, k_blk)[0] = true;
-
-                        // Set the middle block to false
-                        // if (q_blk == q_blocks / 2 && k_blk == k_blocks / 2) {
-                        //     mask.ptr<bool>(h, q_blk, k_blk)[0] = false;
-                        // } else {
-                        //     mask.ptr<bool>(h, q_blk, k_blk)[0] = true;
-                        // }
-                    }
-                }
-            }
             sparse_attention_mask.push_back(std::move(mask));
         }
 
+        // TODO: avoid the temporary vector and assignment here. Consider changing get_sparse_blocks
+        // to take `std::vector<PlainTensor>& sparse_attention_mask` and fill it in place,
+        // reducing PlainTensor copies (memcpy of metadata) and allocations.
+        sparse_attention_mask = get_sparse_blocks(q,
+                                                  k,
+                                                  past_lens,
+                                                  subsequence_begins,
+                                                  block_indices,
+                                                  block_indices_begins,
+                                                  xt_stride,
+                                                  xt_block_size,
+                                                  xt_threshold);
+
         // --- Broadcast sparse_attention_mask to support different block sizes ---
         // The input mask may be [h, q_blocks_orig, k_blocks_orig] and needs to be broadcast to [h, q_blocks, k_blocks]
         auto broadcast_sparse_attention_mask =
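
The body of the broadcast lambda is elided from this hunk. Below is a minimal sketch of its intended semantics, assuming a dense [H, q_blocks, k_blocks] bool layout like the PlainTensor masks above; broadcast_mask and its std::vector<uint8_t> stand-in are illustrative, not the actual helper. Each coarse mask entry is replicated factor = src_block_size / dst_block_size times along both block axes, so with xt_block_size = 128 and runtime block_size = 32 every coarse block fans out to a 4x4 tile of fine blocks.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Expand a coarse [H, src_qb, src_kb] block mask to the finer runtime block size.
    std::vector<uint8_t> broadcast_mask(const std::vector<uint8_t>& src,
                                        size_t H, size_t src_qb, size_t src_kb,
                                        size_t src_block_size, size_t dst_block_size) {
        const size_t factor = src_block_size / dst_block_size;  // e.g. 128 / 32 == 4
        const size_t dst_qb = src_qb * factor;
        const size_t dst_kb = src_kb * factor;
        std::vector<uint8_t> dst(H * dst_qb * dst_kb);
        for (size_t h = 0; h < H; ++h)
            for (size_t q = 0; q < dst_qb; ++q)
                for (size_t k = 0; k < dst_kb; ++k)
                    // Every fine (q, k) block inherits the value of its coarse parent block.
                    dst[(h * dst_qb + q) * dst_kb + k] =
                        src[(h * src_qb + q / factor) * src_kb + k / factor];
        return dst;
    }
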
@@ -2233,20 +2234,19 @@ struct AttentionExecutor : public PagedAttentionExecutor {
            };
        // The original block_size of the sparse attention mask; can be specified later via a PagedAttention node
        // parameter: const size_t sparse_attention_BlockSize = 128;
-       const size_t sparse_attention_BlockSize = 32;
        // Only block_size <= sparse_attention_BlockSize is supported, and sparse_attention_BlockSize must be an
        // integer multiple of block_size
-       if (block_size != sparse_attention_BlockSize) {
-           if (block_size > sparse_attention_BlockSize) {
-               OPENVINO_THROW("not supported: block_size > sparse_attention_BlockSize");
+       if (block_size != xt_block_size) {
+           if (block_size > xt_block_size) {
+               OPENVINO_THROW("not supported: block_size > xt_block_size");
            }
-           if (sparse_attention_BlockSize % block_size != 0) {
-               OPENVINO_THROW("not supported: sparse_attention_BlockSize ",
-                              sparse_attention_BlockSize,
+           if (xt_block_size % block_size != 0) {
+               OPENVINO_THROW("not supported: xt_block_size ",
+                              xt_block_size,
                               " is not an integer multiple of block_size ",
                               block_size);
            }
-           broadcast_sparse_attention_mask(sparse_attention_mask, sparse_attention_BlockSize, block_size);
+           broadcast_sparse_attention_mask(sparse_attention_mask, xt_block_size, block_size);
        }
 
        _helper.init(H,
@@ -2314,6 +2314,27 @@ struct AttentionExecutor : public PagedAttentionExecutor {
        }
    }
 
+   std::vector<PlainTensor> get_sparse_blocks(PlainTensor& q,
+                                              PlainTensor& k,
+                                              PlainTensor& past_lens,
+                                              PlainTensor& subsequence_begins,
+                                              PlainTensor& block_indices,
+                                              PlainTensor& block_indices_begins,
+                                              size_t x_attention_stride,
+                                              size_t x_attention_block_size,
+                                              float threshold) {
+       size_t num_seqs = past_lens.size(0);
+       std::vector<PlainTensor> masks(num_seqs);
+
+       // TODO: support multiple batches
+       for (size_t seq_idx = 0; seq_idx < 1; seq_idx++) {
+           if (q.size(0) > 1) {
+               masks[seq_idx] = xattn_estimate(q, k, x_attention_block_size, x_attention_stride, 1, threshold, true);
+           }
+       }
+       return masks;
+   }
+
    void execute(const std::vector<MemoryPtr>& inputs, const std::vector<MemoryPtr> outputs) override {
        PlainTensor q;
        PlainTensor k;
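
A hedged sketch of the refactor suggested by the two TODOs above: fill the caller's vector in place, and iterate every subsequence instead of only the first. It is meant to live in this same struct; get_sparse_blocks_inplace is a hypothetical name, the unused block_indices parameters are dropped, the per-sequence slicing via q.slice(0, begin, end) is an assumption modelled on the slice call earlier in this diff, and the xattn_estimate signature is taken verbatim from the existing call.

    // Sketch only: in-place, multi-sequence variant of get_sparse_blocks.
    void get_sparse_blocks_inplace(PlainTensor& q,
                                   PlainTensor& k,
                                   PlainTensor& past_lens,
                                   PlainTensor& subsequence_begins,
                                   size_t x_attention_stride,
                                   size_t x_attention_block_size,
                                   float threshold,
                                   std::vector<PlainTensor>& sparse_attention_mask) {
        const size_t num_seqs = past_lens.size(0);
        sparse_attention_mask.clear();
        sparse_attention_mask.resize(num_seqs);
        for (size_t seq_idx = 0; seq_idx < num_seqs; seq_idx++) {
            const auto begin = subsequence_begins.ptr<int32_t>()[seq_idx];
            const auto end = subsequence_begins.ptr<int32_t>()[seq_idx + 1];
            // First-token (prefill) only, mirroring the q.size(0) > 1 guard above.
            if (end - begin > 1) {
                auto q_seq = q.slice(0, begin, end);  // assumed slice(dim, start, end)
                sparse_attention_mask[seq_idx] =
                    xattn_estimate(q_seq, k, x_attention_block_size, x_attention_stride, 1, threshold, true);
            }
        }
    }
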
@@ -2371,16 +2392,6 @@ struct AttentionExecutor : public PagedAttentionExecutor {
                      output_score,
                      sparse_attention_mask);
 
-       PlainTensor sum;
-       PlainTensor mask;
-       auto stride = 16;
-       auto block_size = 128;
-       auto threshold = 0.9f;
-
-       if (q.size(0) > 1) {
-           mask = xattn_estimate(q, k, block_size, stride, 1, threshold, true);
-       }
-
        if (rotated_block_indices) {
            // Rotate kv cache currently doesn't support quantized cache.
            // for u8 it only supports compilation but throws exception in the runtime