[GPU] Add the capability for KV cache to update past KV #33114
Kotomi-Du merged 43 commits into openvinotoolkit:master
Conversation
1. Trigger the trim flag when the Slice pattern is matched.
2. Pass past_seq_len, which is input data, into the trim info.
3. Store the trim info in the KV cache operator and kernel parameters.
4. Update the input[0] and output[0] layouts for the trim.
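The trim step described above can be modeled outside OpenVINO with a minimal numpy sketch (the function name and shapes here are hypothetical illustrations, not plugin code): only the first past_seq_len tokens of the cached K/V are kept before the new tokens are appended.

```python
import numpy as np

def trim_and_append(past_kv: np.ndarray, new_kv: np.ndarray, past_seq_len: int) -> np.ndarray:
    # past_kv: [batch, heads, max_len, head_dim]; trim along the sequence axis,
    # then append the freshly computed K/V for the new tokens.
    trimmed = past_kv[:, :, :past_seq_len, :]
    return np.concatenate([trimmed, new_kv], axis=2)

past = np.zeros((1, 2, 8, 4))   # stale cache with 8 slots
new = np.ones((1, 2, 1, 4))     # K/V for one new token
out = trim_and_append(past, new, past_seq_len=3)
print(out.shape)  # (1, 2, 4, 4)
```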
src/plugins/intel_gpu/src/kernel_selector/cl_kernels/reorder_kv_cache_ref.cl
...ugins/intel_gpu/src/kernel_selector/kernels/reorder_kv_cache/reorder_kv_cache_kernel_ref.cpp
build_jenkins
Is your pipeline only able to run on GPU, not CPU?
Fix past_len checking during initialization
Yes, our pipeline is required to be runnable on GPU.
Fix accuracy when there is no beam_idx; add past_key_len handling
@Kotomi-Du please fix the CI errors.
isanghao
left a comment
LGTM, minor comments are left
// readvalue --> any
// |              |
// |              v
// -------> kvcache
Could you elaborate on why/how it can be optimized?
If read_value is not optimized, we will get incorrect results from ScatterElementsUpdate, so some change here is needed.
The original code simply checks whether readvalue has a single user; to be honest, I don't know that this proves anything --- that user could itself be a no-op with multiple further users.
Judging from the comment in its caller, the check is actually trying to ensure that assign will not impact any following user of readvalue, so the original logic already looks questionable.
In any case, for our pattern, readvalue's users must eventually pass through kvcache before reaching assign, which makes the kvcache node the dominator of the assign node. It is therefore safe to treat readvalue as if it were directly connected to kvcache, and the optimization can be applied.
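The dominance argument above can be sketched with a small, self-contained check (illustrative only; this is not the plugin's actual implementation, and the node names are hypothetical): kvcache dominates assign if every path from readvalue to assign passes through kvcache.

```python
def dominates(graph, src, mid, dst):
    """True if every src->dst path in `graph` (adjacency dict) goes through `mid`."""
    def reaches(a, b, banned):
        stack, seen = [a], set()
        while stack:
            n = stack.pop()
            if n == b:
                return True
            if n in seen or n == banned:
                continue
            seen.add(n)
            stack.extend(graph.get(n, []))
        return False
    # dst must be reachable at all, but unreachable once `mid` is removed
    return reaches(src, dst, None) and not reaches(src, dst, mid)

g = {"read_value": ["gather"], "gather": ["kv_cache"], "kv_cache": ["assign", "sdpa"]}
print(dominates(g, "read_value", "kv_cache", "assign"))  # True
```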
Actually, my ask here was to add a comment explaining the "why/how". As it is not blocking the code merge, could you follow up in a separate PR?
isanghao
left a comment
LGTM, could you check this comment? #33114 (comment)
mryzhov
left a comment
Looks good from the Transformations perspective
Details:
This PR recognizes the pattern of ScatterElementsUpdate + Slice nodes (the blue nodes in the picture below) and fuses them into a multi-stage KVCache node. In addition, past_seq_len from the ONNX GQA operator, which serves to correct the length of the KV cache, was missing from the decomposition of the ONNX operator; it is added in this PR so that it benefits from the new KVCache capability.
After the fusion, two related changes take effect; the picture below shows the graph before and after the fusion.

Motivation and Context
The target application leverages tree-based speculative decoding to accelerate LLM inference. This technique requires frequent manipulation of past KV cache states (e.g. trimming, reordering). This is because only a single branch of the speculative draft tree is accepted after verification.
The KV Cache API currently available in OV is very slow and cannot meet customer requirements (details in CVS-174809). As the OV team suggested, the only way to support the reorder feature is to add specific nodes to the original graph. This PR recognizes the pattern of those added nodes and fuses them into a multi-stage KVCache node for better performance.
Tickets:
CVS-176367
Related PR
#32708