Merged
Changes from 23 commits (43 commits total)
5a95cc6
fuse GQA slice node into kvCache for in-place crop
Kotomi-Du Nov 25, 2025
84d8095
fix conformance issue
Kotomi-Du Dec 2, 2025
ab5be7c
Use RemoteTensor to reorder KV cache
mdvoretc-intel Nov 17, 2025
a0ed479
Add a kernel to reorder KV cache
mdvoretc-intel Nov 19, 2025
1c72628
Add KVCache index fusion for reorder
mdvoretc-intel Nov 20, 2025
4d67227
Fix basic issues
mdvoretc-intel Nov 21, 2025
ff94cfe
Prevent KV reorder execution for cases where it's not required
mdvoretc-intel Nov 21, 2025
c85841f
Fix scalar arguments bug, remove debug prints
mdvoretc-intel Nov 27, 2025
817c983
Remove unused gather_by_axis code
mdvoretc-intel Dec 2, 2025
b9d5f30
Fix input offsets
mdvoretc-intel Dec 3, 2025
07c75d8
Add unit test case
mdvoretc-intel Dec 3, 2025
1586380
Add feature bounds check
mdvoretc-intel Dec 3, 2025
d338bcc
clean up code
Kotomi-Du Dec 4, 2025
648a5fc
clean up execution stage
Kotomi-Du Dec 5, 2025
d7043fe
use scatterElementUpdate kernel instead of self customized kernel
Kotomi-Du Dec 5, 2025
145e0f5
delete customized kernel path
Kotomi-Du Dec 5, 2025
e01cedf
clean up code
Kotomi-Du Dec 6, 2025
4c2d73a
fix code style
Kotomi-Du Dec 10, 2025
9db4cc7
adjust index for compressed KV stage when update_kv stage exists
Kotomi-Du Dec 10, 2025
73a739e
refactor tests, merge duplicated code
ZackyLake Dec 18, 2025
7914ddd
refactor kvcache stage.
ZackyLake Dec 18, 2025
8e74647
remove update_kv logic on compress kv.
ZackyLake Dec 30, 2025
9bfb862
remove indirect support on kv_update due to lack of test.
ZackyLake Dec 30, 2025
5326963
add debug print for skipped kernel
ZackyLake Jan 7, 2026
f321f9b
Merge branch 'master' into update_kvcache_node
ZackyLake Jan 7, 2026
ea925ed
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 7, 2026
6cc91a1
fix kv fusion pattern.
ZackyLake Jan 10, 2026
561bc54
Merge branch 'master' into update_kvcache_node
ZackyLake Jan 10, 2026
d836186
fix test
ZackyLake Jan 12, 2026
6344883
move trim_length to kv_cache_inst.h
Kotomi-Du Jan 7, 2026
4d06b8e
fix kv fusion for test(stridedslice)
ZackyLake Jan 15, 2026
644cc75
include trim-only support
ZackyLake Jan 15, 2026
abcdf72
fix concat_axis signness
ZackyLake Jan 15, 2026
a1ecd57
fix fusion logic
ZackyLake Jan 15, 2026
f8f1a58
fix signedness
ZackyLake Jan 15, 2026
3d65d54
allow trim on indirect kvcache.
ZackyLake Jan 16, 2026
313a748
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 16, 2026
c429feb
Make CompressedKV compatible with trim.
ZackyLake Jan 17, 2026
6781639
fix
ZackyLake Jan 17, 2026
6a3bbd2
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 17, 2026
3df8bac
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 18, 2026
765e167
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 19, 2026
0dc75bc
add comment
ZackyLake Jan 20, 2026
@@ -79,6 +79,10 @@ struct kernel_impl_params final {
std::vector<size_t> output_size;
std::vector<size_t> img_size;

// KV cache trim length - set at runtime during shape inference
// Marked as mutable to allow modification even when kernel_impl_params is passed as const reference
mutable int64_t kv_cache_trim_length = 0;
Contributor:
This is not the right place to add such a KV-cache-specific field.

Contributor Author (@Kotomi-Du, Jan 6, 2026):
Could you suggest another place to put this? Here is our investigation.

The table below lists all the available KVCache-related files, but none of them is a suitable home for this parameter, which needs to be updated every iteration at runtime.

Specifically, for kv_cache_inst.h, kv_cache_inst::trim_length cannot be updated inside the static function calc_output_layout(). Making kv_cache_inst::trim_length static to work around this is also not an option, because it would cause data races across multiple KV-cache instances or threads.

Furthermore, kernel_impl_params.h already contains other op-specific variables with a TODO comment (prior-box), so this seems acceptable in our case as well.
| Description | File Path | Notes |
| --- | --- | --- |
| ov::intel_gpu::op::KVCache | src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp | it cannot serve for runtime |
| cldnn::kv_cache | src/plugins/intel_gpu/include/intel_gpu/primitives/kv_cache.hpp | it stores static attributes and can be serialized into the kernel's binary cache |
| cldnn::typed_primitive_inst<kv_cache> | src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h | node instance |
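To make the data-race concern above concrete, here is a minimal standalone sketch (the struct names are simplified stand-ins, not the real cldnn types) of why a static trim_length would be shared across all KV-cache instances, while an ordinary data member is kept per instance:

```cpp
#include <cstdint>

// Stand-in for a kv_cache instance with a STATIC trim length:
// one value shared by every instance in the program.
struct kv_cache_inst_static {
    static int64_t trim_length;
};
int64_t kv_cache_inst_static::trim_length = 0;

// Stand-in for a kv_cache instance with a PER-INSTANCE trim length.
struct kv_cache_inst_member {
    int64_t trim_length = 0;
};

// With a static member, the second instance's write clobbers the first's.
int64_t static_value_seen_by_first_instance() {
    kv_cache_inst_static a, b;
    a.trim_length = 3;   // writes the shared static...
    b.trim_length = 7;   // ...and this overwrites it
    return a.trim_length;   // 7, not 3 — state leaked between instances
}

// With a plain member, each instance keeps its own value.
int64_t member_value_seen_by_first_instance() {
    kv_cache_inst_member a, b;
    a.trim_length = 3;
    b.trim_length = 7;
    return a.trim_length;   // 3 — independent state
}
```

Even single-threaded, the static variant silently couples distinct KV caches (e.g. the key cache and value cache of one layer); with multiple inference threads it additionally becomes an unsynchronized shared write.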

Contributor:
What about introducing a separate method in kv_cache_inst.h? You can make it a non-static method and use it to store the information in primitive_inst. That API can then simply be called after shape inference. You should not place this field in kernel_impl_params.
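A minimal sketch of this suggestion (the type and function names here are illustrative stand-ins, not the actual OpenVINO API): keep trim_length as a non-static member of the kv_cache instance and set it through an accessor right after shape inference computes it:

```cpp
#include <cstdint>

// Stand-in for cldnn::primitive_inst.
struct primitive_inst_stub {
    virtual ~primitive_inst_stub() = default;
};

// Stand-in for cldnn::typed_primitive_inst<kv_cache>: the runtime trim
// length lives on the instance, not on kernel_impl_params.
struct kv_cache_inst_stub : primitive_inst_stub {
    // Non-static accessors: each instance owns its own runtime value.
    void set_trim_length(int64_t len) { m_trim_length = len; }
    int64_t get_trim_length() const { return m_trim_length; }

private:
    int64_t m_trim_length = 0;
};

// Hypothetical call site: invoked once per iteration, right after shape
// inference has determined how many cached tokens must be dropped.
void after_shape_inference(kv_cache_inst_stub& inst, int64_t computed_trim) {
    inst.set_trim_length(computed_trim);
}
```

The kernel implementation can then read the value back from the instance when enqueuing work, so kernel_impl_params stays free of op-specific runtime state.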

Contributor:
The updated implementation looks good for trim_length.


std::map<size_t, size_t> in_port_to_shape_info_offset = {};
std::map<size_t, size_t> out_port_to_shape_info_offset = {};

17 changes: 17 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp
@@ -33,6 +33,16 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
int64_t gather_axis,
const ov::element::Type output_type = ov::element::dynamic);

/// KVCache with reorder for tree-based speculative decoding
KVCache(const Output<Node>& past,
const Output<Node>& new_token_data,
const Output<Node>& past_seq_len,
const Output<Node>& dst_idx,
const Output<Node>& update_data,
const std::shared_ptr<ov::op::util::Variable>& past_values,
int64_t concat_axis,
const ov::element::Type output_type = ov::element::dynamic);

bool visit_attributes(ov::AttributeVisitor& visitor) override;

void validate_and_infer_types() override;
@@ -51,6 +61,11 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
void set_gather_axis(int64_t axis) { m_gather_axis = axis; }

bool get_indirect() const { return m_indirect; }
bool get_update_kv() const { return m_update_kv; }
void set_update_kv(bool update_kv) { m_update_kv = update_kv; }

uint64_t get_trim_length() const { return m_trim_length; }
void set_trim_length(uint64_t trim_length) { m_trim_length = trim_length; }

protected:
KVCache(const OutputVector& inputs,
@@ -63,6 +78,8 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
int64_t m_concat_axis = 0;
int64_t m_gather_axis = 0;
bool m_indirect = false;
bool m_update_kv = false;
uint64_t m_trim_length = 0;

ov::element::Type m_output_type;
};
@@ -27,18 +27,20 @@ struct kv_cache : public primitive_base<kv_cache> {
const ov::op::util::VariableInfo& variable_info,
const int64_t concat_axis,
const int64_t gather_axis,
const bool indirect)
const bool indirect,
const bool update_kv)
: primitive_base(id, inputs)
, variable_info(variable_info)
, concat_axis(concat_axis)
, gather_axis(gather_axis)
, indirect(indirect) {}
, indirect(indirect)
, update_kv(update_kv) {}

ov::op::util::VariableInfo variable_info;
int64_t concat_axis = 0;
int64_t gather_axis = 0;
bool indirect = false;

bool update_kv = false;
bool compressed = false;
QuantizationAttributes quantization_attributes;

@@ -47,6 +49,7 @@ struct kv_cache : public primitive_base<kv_cache> {
seed = hash_combine(seed, concat_axis);
seed = hash_combine(seed, gather_axis);
seed = hash_combine(seed, indirect);
seed = hash_combine(seed, update_kv);
seed = hash_combine(seed, compressed);
seed = hash_range(seed, quantization_attributes.scales_zp_output_order.begin(), quantization_attributes.scales_zp_output_order.end());
seed = hash_range(seed, quantization_attributes.group_sizes.begin(), quantization_attributes.group_sizes.end());
@@ -69,6 +72,7 @@ struct kv_cache : public primitive_base<kv_cache> {
concat_axis == rhs_casted.concat_axis &&
gather_axis == rhs_casted.gather_axis &&
indirect == rhs_casted.indirect &&
update_kv == rhs_casted.update_kv &&
compressed == rhs_casted.compressed &&
quantization_attributes.scales_zp_output_order == rhs_casted.quantization_attributes.scales_zp_output_order &&
quantization_attributes.output_storage_type == rhs_casted.quantization_attributes.output_storage_type &&
@@ -88,6 +92,7 @@ struct kv_cache : public primitive_base<kv_cache> {
ob << concat_axis;
ob << gather_axis;
ob << indirect;
ob << update_kv;
ob << compressed;
ob << make_data(&quantization_attributes.quantization_type, sizeof(quantization_attributes.quantization_type));
ob << make_data(&quantization_attributes.quantization_dt, sizeof(quantization_attributes.quantization_dt));
@@ -110,6 +115,7 @@ struct kv_cache : public primitive_base<kv_cache> {
ib >> concat_axis;
ib >> gather_axis;
ib >> indirect;
ib >> update_kv;
ib >> compressed;
ib >> make_data(&quantization_attributes.quantization_type, sizeof(quantization_attributes.quantization_type));
ib >> make_data(&quantization_attributes.quantization_dt, sizeof(quantization_attributes.quantization_dt));