
[GPU] Enable multi head size support for KV cache #29936


Open: clee30 wants to merge 2 commits into master from kv_multiheadsize

Conversation

clee30
Contributor

@clee30 clee30 commented Apr 4, 2025

In continuous batching, the head sizes for key and value can differ. Add support for this in SDPA.

Tickets:

CVS-162339 and CVS-161089
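For context, a minimal sketch of the tensor shapes involved when the key and value head sizes differ (illustrative only, not code from this PR):

    // Illustrative shapes, assuming separate key and value head sizes:
    //   Q: [num_tokens,            num_heads, k_head_size]
    //   K: [past_len + num_tokens, num_heads, k_head_size]
    //   V: [past_len + num_tokens, num_heads, v_head_size]
    // Q*K^T is computed over k_head_size, while the attention output follows
    // the value head size:
    //   O: [num_tokens, num_heads, v_head_size]
    struct SdpaShapes {
        int num_tokens, past_len, num_heads, k_head_size, v_head_size;
        int output_elements() const { return num_tokens * num_heads * v_head_size; }
    };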

@clee30 clee30 requested review from a team as code owners April 4, 2025 07:24
@github-actions github-actions bot added the category: GPU OpenVINO GPU plugin label Apr 4, 2025
@sys-openvino-ci sys-openvino-ci added the ExternalIntelPR External contributor from Intel label Apr 4, 2025
@p-durandin
Contributor

build_jenkins

@sshlyapn sshlyapn added this to the 2025.2 milestone Apr 4, 2025
@clee30 clee30 force-pushed the kv_multiheadsize branch from 4ed2f07 to 9e5f1a8 on April 4, 2025 08:41
@p-durandin
Contributor

build_jenkins

@clee30 clee30 force-pushed the kv_multiheadsize branch from 9e5f1a8 to 0960ea1 on April 7, 2025 02:08
@p-durandin
Contributor

build_jenkins

@clee30 clee30 closed this Apr 8, 2025
@clee30 clee30 force-pushed the kv_multiheadsize branch from 3c5d405 to d119656 on April 8, 2025 10:02
@clee30 clee30 reopened this Apr 8, 2025
@p-durandin
Contributor

build_jenkins

@clee30
Contributor Author

clee30 commented Apr 9, 2025

Found a regression when running qwen2-7b with paged_attention. Need to investigate it.

@clee30 clee30 force-pushed the kv_multiheadsize branch from d44fd0a to dd777a0 on April 9, 2025 08:29
@p-durandin
Contributor

build_jenkins

@yeonbok
Contributor

yeonbok commented Apr 10, 2025

Is the regression issue resolved? If not, please add a "Do not merge" or "Under perf check" label.

@clee30
Contributor Author

clee30 commented Apr 10, 2025

Is the regression issue resolved? If not, please add a "Do not merge" or "Under perf check" label.

Yes, it is resolved.

@p-durandin
Contributor

build_jenkins

@clee30 clee30 force-pushed the kv_multiheadsize branch from 77bff17 to 8e62237 on April 14, 2025 08:36
@p-durandin
Contributor

build_jenkins

In continuous batching, the head sizes for key and value can differ.
Add support for this in SDPA.
@clee30 clee30 force-pushed the kv_multiheadsize branch from 94b1f5c to afb86d6 on April 14, 2025 09:14
@clee30 clee30 requested a review from a team as a code owner April 14, 2025 09:14
@github-actions github-actions bot added the category: C API OpenVINO C API bindings label Apr 14, 2025
@p-durandin
Contributor

build_jenkins

@clee30 clee30 force-pushed the kv_multiheadsize branch from afb86d6 to 782a845 on April 16, 2025 03:27
@github-actions github-actions bot removed the category: C API OpenVINO C API bindings label Apr 16, 2025
@mlukasze mlukasze dismissed their stale review April 16, 2025 04:44

request fulfilled, blocking review removed

@mlukasze
Contributor

build_jenkins

@mlukasze mlukasze self-requested a review April 16, 2025 04:45
@@ -35,7 +35,8 @@ struct paged_attention : public primitive_base<paged_attention> {

auto rhs_casted = downcast<const paged_attention>(rhs);

- return head_size == rhs_casted.head_size &&
+ return k_head_size == rhs_casted.k_head_size &&
Contributor

[random spot]
Please add SDPA tests for different key/value head_size combinations.

Contributor Author

I have updated the paged_attention_gpu_test.cpp unit tests to cover SDPA with different key/value head sizes.

Contributor

@clee30, let's update the SDPA tests as well to cover more cases and kernels, not limited to the PA operation only. We have:

  1. SDPA single layer test
  2. SDPA subgraph tests with KV cache operation executing multiple iterations

Please update them with v_head_size != k_head_size cases (see the illustrative sketch below).
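A minimal sketch of such a parameterization, using hypothetical names (the real OpenVINO SDPA/PA test fixtures have their own parameter structs and builders):

    // Hypothetical sketch only; names below are not from the OpenVINO test suite.
    #include <vector>

    struct SdpaHeadSizeCase {
        int num_heads;
        int k_head_size;   // head size of Q and K
        int v_head_size;   // head size of V and of the SDPA output
    };

    // Cover equal and unequal key/value head sizes in both directions.
    static const std::vector<SdpaHeadSizeCase> kHeadSizeCases = {
        {8,  64,  64},   // baseline: k_head_size == v_head_size
        {8, 128,  64},   // k_head_size > v_head_size
        {8,  64, 128},   // k_head_size < v_head_size
    };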

@clee30 clee30 force-pushed the kv_multiheadsize branch from a2d1194 to c78fe42 on April 17, 2025 13:57
Besides, also fix the functional test for SDPA.
@clee30 clee30 force-pushed the kv_multiheadsize branch from c78fe42 to f2a7f59 on April 17, 2025 14:30
@p-durandin
Contributor

build_jenkins

Comment on lines +29 to +31
auto data_shape = { q_layout.get_partial_shape()[0],
ov::Dimension(desc->heads_num),
ov::Dimension(desc->v_head_size) };
Contributor

PagedAttention has a 2D output [batch_size_in_tokens, num_heads * head_size], so this should be [q_layout.get_partial_shape()[0], desc->heads_num * desc->v_head_size].
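In code, the suggestion would look roughly like the following (a sketch reusing the q_layout and desc names from the excerpt above, not the exact PR change):

    // Fold the two per-head dimensions into one so the shape matches
    // PagedAttention's 2D output [batch_size_in_tokens, num_heads * v_head_size].
    auto data_shape = { q_layout.get_partial_shape()[0],
                        ov::Dimension(desc->heads_num * desc->v_head_size) };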

@@ -246,7 +246,7 @@ KERNEL(sdpa_opt)(
// Main Gemm1 calculation loop
// Each SG performs element-wise multiplications of Q[HEAD_SIZE]xK[HEAD_SIZE] values
// HEAD_SIZE / SUBGROUPS_PER_WG times in the loop and saves the result to the qk_local SLM buffer
- for (uint seq_len = sgid; seq_len < partition_seq_len; seq_len += (HEAD_SIZE / SUBGROUP_SIZE) * SG_SCALE_FACTOR) {
+ for (uint seq_len = sgid; seq_len < partition_seq_len; seq_len += (V_HEAD_SIZE / SUBGROUP_SIZE) * SG_SCALE_FACTOR) {
Contributor

nit:

Suggested change
- for (uint seq_len = sgid; seq_len < partition_seq_len; seq_len += (V_HEAD_SIZE / SUBGROUP_SIZE) * SG_SCALE_FACTOR) {
+ for (uint seq_len = sgid; seq_len < partition_seq_len; seq_len += SUBGROUPS_PER_WG) {
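The equivalence behind this nit, with assumed values (the SUBGROUPS_PER_WG definition below is an assumption for illustration, not quoted from the kernel):

    // Assumed for illustration: SUBGROUPS_PER_WG == (V_HEAD_SIZE / SUBGROUP_SIZE) * SG_SCALE_FACTOR.
    constexpr unsigned SUBGROUP_SIZE   = 16;  // illustrative
    constexpr unsigned V_HEAD_SIZE     = 64;  // illustrative
    constexpr unsigned SG_SCALE_FACTOR = 2;   // illustrative
    constexpr unsigned SUBGROUPS_PER_WG = (V_HEAD_SIZE / SUBGROUP_SIZE) * SG_SCALE_FACTOR;  // 4 * 2 == 8
    // Under this assumption the loop stride (V_HEAD_SIZE / SUBGROUP_SIZE) * SG_SCALE_FACTOR
    // equals SUBGROUPS_PER_WG, so the macro is a more readable spelling of the same value.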

Comment on lines 226 to 230
#if SG_SCALE_FACTOR == 2
- if (sgid < HEAD_SIZE / SUBGROUP_SIZE) {
+ if (sgid < V_HEAD_SIZE / SUBGROUP_SIZE) {
#else
{
#endif
Contributor

What if v_head_size is smaller than k_head_size? Then we don't have enough subgroups to load all the required k_head_size query inputs.
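To make the concern concrete, a worked example with assumed kernel parameters (values are illustrative, not taken from this thread):

    // Assumed, illustrative values only.
    constexpr unsigned SUBGROUP_SIZE = 16;
    constexpr unsigned K_HEAD_SIZE   = 128;  // query/key head size
    constexpr unsigned V_HEAD_SIZE   = 64;   // value head size, smaller than K_HEAD_SIZE
    // Subgroups admitted by the guard `sgid < V_HEAD_SIZE / SUBGROUP_SIZE`:
    constexpr unsigned active_subgroups = V_HEAD_SIZE / SUBGROUP_SIZE;  // 4
    // Subgroups needed to load a K_HEAD_SIZE-wide query row, SUBGROUP_SIZE lanes each:
    constexpr unsigned needed_subgroups = K_HEAD_SIZE / SUBGROUP_SIZE;  // 8
    static_assert(active_subgroups < needed_subgroups,
                  "with v_head_size < k_head_size, part of the query row is never loaded");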

@@ -409,7 +409,7 @@ KERNEL(sdpa_opt)(
const uint seq_idx_end = 1;
for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) {
// Iterate over all values QK values in SLM and apply scale and attention mask
- for (uint seq_len = sgid * SUBGROUP_SIZE + sglid; seq_len < partition_seq_len; seq_len += (HEAD_SIZE * SG_SCALE_FACTOR)) {
+ for (uint seq_len = sgid * SUBGROUP_SIZE + sglid; seq_len < partition_seq_len; seq_len += (V_HEAD_SIZE * SG_SCALE_FACTOR)) {
Contributor

nit:

Suggested change
- for (uint seq_len = sgid * SUBGROUP_SIZE + sglid; seq_len < partition_seq_len; seq_len += (V_HEAD_SIZE * SG_SCALE_FACTOR)) {
+ for (uint seq_len = sgid * SUBGROUP_SIZE + sglid; seq_len < partition_seq_len; seq_len += SUBGROUPS_PER_WG * SUBGROUP_SIZE) {

@@ -908,24 +908,26 @@ KERNEL(sdpa_opt)(
#else
const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE);
#endif
+ const uint num_read_blocks = K_HEAD_SIZE == V_HEAD_SIZE ? 1 : (K_HEAD_SIZE - V_HEAD_SIZE) / TARGET_SEQ_LEN_BLOCK_SIZE + SG_SCALE_FACTOR;
Contributor

If v_head_size=128 and k_head_size=64, should this value be negative?

Contributor Author

Is there a case where v_head_size > k_head_size?

Contributor

Not sure, but it seems there is no reason why it can't happen.
I'd suggest implementing support for the general k_head_size != v_head_size case without assuming any relationship between them.
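For reference, the arithmetic behind the question above, with assumed values (TARGET_SEQ_LEN_BLOCK_SIZE and SG_SCALE_FACTOR are illustrative, not taken from this thread):

    // Assumed, illustrative values only.
    constexpr int K_HEAD_SIZE = 64;
    constexpr int V_HEAD_SIZE = 128;               // v_head_size > k_head_size
    constexpr int TARGET_SEQ_LEN_BLOCK_SIZE = 16;  // assumed
    constexpr int SG_SCALE_FACTOR = 1;             // assumed
    constexpr int num_read_blocks =
        (K_HEAD_SIZE - V_HEAD_SIZE) / TARGET_SEQ_LEN_BLOCK_SIZE + SG_SCALE_FACTOR;  // (64 - 128) / 16 + 1 == -3
    static_assert(num_read_blocks < 0, "the expression can go negative when v_head_size > k_head_size");

Since the kernel declares num_read_blocks as a uint, a negative result would additionally wrap around to a very large unsigned value.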

Comment on lines 70 to +71
int head_size;
int v_head_size;
Contributor

nit:
Please replace all occurrences of head_size -> k_head_size in this test

Suggested change
- int head_size;
- int v_head_size;
+ int k_head_size;
+ int v_head_size;

Comment on lines +78 to +80
std::vector<std::vector<ov::float16>> query_data; // {[1, num_tokens, num_heads, head_size, v_head_size], ..}
std::vector<std::vector<ov::float16>> key_data; // {[1, past_len + num_tokens, num_heads, head_size, v_head_size], ..}
std::vector<std::vector<ov::float16>> value_data; // {[1, past_len + num_tokens, num_heads, head_size, v_head_size], ..}
Contributor

Suggested change
- std::vector<std::vector<ov::float16>> query_data; // {[1, num_tokens, num_heads, head_size, v_head_size], ..}
- std::vector<std::vector<ov::float16>> key_data;   // {[1, past_len + num_tokens, num_heads, head_size, v_head_size], ..}
- std::vector<std::vector<ov::float16>> value_data; // {[1, past_len + num_tokens, num_heads, head_size, v_head_size], ..}
+ std::vector<std::vector<ov::float16>> query_data; // {[1, num_tokens, num_heads, k_head_size], ..}
+ std::vector<std::vector<ov::float16>> key_data;   // {[1, past_len + num_tokens, num_heads, k_head_size], ..}
+ std::vector<std::vector<ov::float16>> value_data; // {[1, past_len + num_tokens, num_heads, v_head_size], ..}
