[GPU] Enable multi head size support for KV cache #29936
base: master
Conversation
Found a regression issue when running qwen2-7b with paged_attention. Need to check on it.
Is the regression issue resolved? If not, please add a label of "Do not merge" or "Under perf check".
Yes, it is solved.
Force-pushed from 77bff17 to 8e62237.
In continuous batching, the head size for key and value can differ. Add support for this in SDPA.
Force-pushed from 94b1f5c to afb86d6.
Force-pushed from afb86d6 to 782a845.
Request fulfilled; blocking review removed.
@@ -35,7 +35,8 @@ struct paged_attention : public primitive_base<paged_attention> {

     auto rhs_casted = downcast<const paged_attention>(rhs);

-    return head_size == rhs_casted.head_size &&
+    return k_head_size == rhs_casted.k_head_size &&
Please add SDPA tests for different key/value head_size combinations.
I have updated the paged_attention_gpu_test.cpp unit tests to cover SDPA with different key/value head sizes.
@clee30, let's update the SDPA tests as well to cover more cases and kernels, not limited to the PA operation only. We have:
- SDPA single layer test
- SDPA subgraph tests with KV cache operation executing multiple iterations
Please update them with v_head_size != k_head_size cases (for example, along the lines of the sketch below).
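A minimal illustration of the kind of parameter coverage meant here; the struct name and values are hypothetical, not the actual test fixture:

#include <cstddef>
#include <vector>

// Hypothetical parameter set for SDPA / KV-cache tests with differing head sizes.
struct sdpa_head_size_params {
    std::size_t num_heads;
    std::size_t k_head_size;   // head size of query/key
    std::size_t v_head_size;   // head size of value
};

// Cover equal sizes plus both k > v and v > k combinations.
const std::vector<sdpa_head_size_params> head_size_cases = {
    {28, 128, 128},  // baseline: equal head sizes
    {28, 192, 128},  // k_head_size > v_head_size
    {28, 64, 128},   // v_head_size > k_head_size
};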
Force-pushed from a2d1194 to c78fe42.
In addition, the functional test for SDPA is also fixed.
Force-pushed from c78fe42 to f2a7f59.
auto data_shape = { q_layout.get_partial_shape()[0],
                    ov::Dimension(desc->heads_num),
                    ov::Dimension(desc->v_head_size) };
PagedAttention has a 2D output [batch_size_in_tokens, num_heads * head_size], so it should be something like [q_layout.get_partial_shape()[0], desc->heads_num * desc->v_head_size].
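A sketch of the fix the comment is suggesting, reusing q_layout and desc from the surrounding code; this mirrors the comment rather than the final implementation:

#include <openvino/core/partial_shape.hpp>  // ov::PartialShape, ov::Dimension

// PagedAttention output is 2D: [batch_size_in_tokens, num_heads * v_head_size]
auto data_shape = ov::PartialShape{ q_layout.get_partial_shape()[0],
                                    ov::Dimension(desc->heads_num * desc->v_head_size) };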
@@ -246,7 +246,7 @@ KERNEL(sdpa_opt)(
     // Main Gemm1 calculation loop
     // Each SG performs element-wise multiplications of Q[HEAD_SIZE]xK[HEAD_SIZE] values
     // HEAD_SIZE / SUBGROUPS_PER_WG times in the loop and saves the result to the qk_local SLM buffer
-    for (uint seq_len = sgid; seq_len < partition_seq_len; seq_len += (HEAD_SIZE / SUBGROUP_SIZE) * SG_SCALE_FACTOR) {
+    for (uint seq_len = sgid; seq_len < partition_seq_len; seq_len += (V_HEAD_SIZE / SUBGROUP_SIZE) * SG_SCALE_FACTOR) {
nit:
-for (uint seq_len = sgid; seq_len < partition_seq_len; seq_len += (V_HEAD_SIZE / SUBGROUP_SIZE) * SG_SCALE_FACTOR) {
+for (uint seq_len = sgid; seq_len < partition_seq_len; seq_len += SUBGROUPS_PER_WG) {
#if SG_SCALE_FACTOR == 2
-    if (sgid < HEAD_SIZE / SUBGROUP_SIZE) {
+    if (sgid < V_HEAD_SIZE / SUBGROUP_SIZE) {
#else
     {
#endif
What if v_head_size is smaller than k_head_size? Then we don't have enough subgroups to load all the required k_head_size query inputs
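To make the concern concrete, a small worked example with assumed numbers (SUBGROUP_SIZE = 16 is an assumption, and the head sizes are purely illustrative):

// Under SG_SCALE_FACTOR == 2 the guard above admits only V_HEAD_SIZE / SUBGROUP_SIZE subgroups.
constexpr unsigned SUBGROUP_SIZE = 16;   // assumed
constexpr unsigned k_head_size = 128;    // illustrative
constexpr unsigned v_head_size = 64;     // illustrative: smaller than k_head_size
constexpr unsigned sgs_admitted = v_head_size / SUBGROUP_SIZE;  // 4 subgroups pass the guard
constexpr unsigned sgs_needed   = k_head_size / SUBGROUP_SIZE;  // 8 subgroups' worth of query loads
static_assert(sgs_needed > sgs_admitted,
              "with v_head_size < k_head_size there are not enough subgroups for the query loads");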
@@ -409,7 +409,7 @@ KERNEL(sdpa_opt)(
     const uint seq_idx_end = 1;
     for (uint seq_idx = 0; seq_idx < seq_idx_end; seq_idx++) {
         // Iterate over all values QK values in SLM and apply scale and attention mask
-        for (uint seq_len = sgid * SUBGROUP_SIZE + sglid; seq_len < partition_seq_len; seq_len += (HEAD_SIZE * SG_SCALE_FACTOR)) {
+        for (uint seq_len = sgid * SUBGROUP_SIZE + sglid; seq_len < partition_seq_len; seq_len += (V_HEAD_SIZE * SG_SCALE_FACTOR)) {
nit:
-for (uint seq_len = sgid * SUBGROUP_SIZE + sglid; seq_len < partition_seq_len; seq_len += (V_HEAD_SIZE * SG_SCALE_FACTOR)) {
+for (uint seq_len = sgid * SUBGROUP_SIZE + sglid; seq_len < partition_seq_len; seq_len += SUBGROUPS_PER_WG * SUBGROUP_SIZE) {
@@ -908,24 +908,26 @@ KERNEL(sdpa_opt)(
 #else
     const uint seq_idx_end = min(TARGET_SEQ_LEN - target_seq_idx, (uint)TARGET_SEQ_LEN_BLOCK_SIZE);
 #endif
+    const uint num_read_blocks = K_HEAD_SIZE == V_HEAD_SIZE ? 1 : (K_HEAD_SIZE - V_HEAD_SIZE) / TARGET_SEQ_LEN_BLOCK_SIZE + SG_SCALE_FACTOR;
If v_head_size = 128 and k_head_size = 64, wouldn't this value be negative?
Is there a case where v_head_size > k_head_size?
Not sure, but there seems to be no reason why it can't happen.
I'd suggest implementing support for the general k_head_size != v_head_size case without assuming any relation between them.
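For illustration, plugging the numbers from the question above into the current expression (TARGET_SEQ_LEN_BLOCK_SIZE = 16 and SG_SCALE_FACTOR = 2 are assumed values, not taken from the kernel):

// Illustrative stand-ins for the kernel's JIT constants.
constexpr int K_HEAD_SIZE = 64;
constexpr int V_HEAD_SIZE = 128;
constexpr int TARGET_SEQ_LEN_BLOCK_SIZE = 16;  // assumed
constexpr int SG_SCALE_FACTOR = 2;             // assumed

// (64 - 128) / 16 + 2 == -2; assigned to a uint in the kernel this would wrap to a huge block count.
constexpr int num_read_blocks =
    K_HEAD_SIZE == V_HEAD_SIZE ? 1
                               : (K_HEAD_SIZE - V_HEAD_SIZE) / TARGET_SEQ_LEN_BLOCK_SIZE + SG_SCALE_FACTOR;
static_assert(num_read_blocks == -2, "the expression goes negative when v_head_size > k_head_size");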
int head_size;
int v_head_size;
nit: please replace all occurrences of head_size with k_head_size in this test:
-int head_size;
+int k_head_size;
 int v_head_size;
std::vector<std::vector<ov::float16>> query_data; // {[1, num_tokens, num_heads, head_size, v_head_size], ..}
std::vector<std::vector<ov::float16>> key_data;   // {[1, past_len + num_tokens, num_heads, head_size, v_head_size], ..}
std::vector<std::vector<ov::float16>> value_data; // {[1, past_len + num_tokens, num_heads, head_size, v_head_size], ..}
-std::vector<std::vector<ov::float16>> query_data; // {[1, num_tokens, num_heads, head_size, v_head_size], ..}
-std::vector<std::vector<ov::float16>> key_data;   // {[1, past_len + num_tokens, num_heads, head_size, v_head_size], ..}
-std::vector<std::vector<ov::float16>> value_data; // {[1, past_len + num_tokens, num_heads, head_size, v_head_size], ..}
+std::vector<std::vector<ov::float16>> query_data; // {[1, num_tokens, num_heads, k_head_size], ..}
+std::vector<std::vector<ov::float16>> key_data;   // {[1, past_len + num_tokens, num_heads, k_head_size], ..}
+std::vector<std::vector<ov::float16>> value_data; // {[1, past_len + num_tokens, num_heads, v_head_size], ..}
In continuous batching, the head size for key and value can differ. Add support for this in SDPA.
Tickets:
CVS-162339 and CVS-161089