Merged
43 commits
5a95cc6
fuse GQA slice node into kvCache for in-place crop
Kotomi-Du Nov 25, 2025
84d8095
fix conformance issue
Kotomi-Du Dec 2, 2025
ab5be7c
Use RemoteTensor to reorder KV cache
mdvoretc-intel Nov 17, 2025
a0ed479
Add a kernel to reorder KV cache
mdvoretc-intel Nov 19, 2025
1c72628
Add KVCache index fusion for reorder
mdvoretc-intel Nov 20, 2025
4d67227
Fix basic issues
mdvoretc-intel Nov 21, 2025
ff94cfe
Prevent KV reorder execution for cases where it's not required
mdvoretc-intel Nov 21, 2025
c85841f
Fix scalar arguments bug, remove debug prints
mdvoretc-intel Nov 27, 2025
817c983
Remove unused gather_by_axis code
mdvoretc-intel Dec 2, 2025
b9d5f30
Fix input offsets
mdvoretc-intel Dec 3, 2025
07c75d8
Add unit test case
mdvoretc-intel Dec 3, 2025
1586380
Add feature bounds check
mdvoretc-intel Dec 3, 2025
d338bcc
clean up code
Kotomi-Du Dec 4, 2025
648a5fc
clean up execution stage
Kotomi-Du Dec 5, 2025
d7043fe
use scatterElementUpdate kernel instead of self customized kernel
Kotomi-Du Dec 5, 2025
145e0f5
delete customized kernel path
Kotomi-Du Dec 5, 2025
e01cedf
clean up code
Kotomi-Du Dec 6, 2025
4c2d73a
fix code style
Kotomi-Du Dec 10, 2025
9db4cc7
adjust index for compressed KV stage when update_kv stage exists
Kotomi-Du Dec 10, 2025
73a739e
refactor tests, merge duplicated code
ZackyLake Dec 18, 2025
7914ddd
refactor kvcache stage.
ZackyLake Dec 18, 2025
8e74647
remove update_kv logic on compress kv.
ZackyLake Dec 30, 2025
9bfb862
remove indirect support on kv_update due to lack of test.
ZackyLake Dec 30, 2025
5326963
add debug print for skipped kernel
ZackyLake Jan 7, 2026
f321f9b
Merge branch 'master' into update_kvcache_node
ZackyLake Jan 7, 2026
ea925ed
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 7, 2026
6cc91a1
fix kv fusion pattern.
ZackyLake Jan 10, 2026
561bc54
Merge branch 'master' into update_kvcache_node
ZackyLake Jan 10, 2026
d836186
fix test
ZackyLake Jan 12, 2026
6344883
move trim_length to kv_cache_inst.h
Kotomi-Du Jan 7, 2026
4d06b8e
fix kv fusion for test(stridedslice)
ZackyLake Jan 15, 2026
644cc75
include trim-only support
ZackyLake Jan 15, 2026
abcdf72
fix concat_axis signness
ZackyLake Jan 15, 2026
a1ecd57
fix fusion logic
ZackyLake Jan 15, 2026
f8f1a58
fix signedness
ZackyLake Jan 15, 2026
3d65d54
allow trim on indirect kvcache.
ZackyLake Jan 16, 2026
313a748
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 16, 2026
c429feb
Make CompressedKV compatible with trim.
ZackyLake Jan 17, 2026
6781639
fix
ZackyLake Jan 17, 2026
6a3bbd2
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 17, 2026
3df8bac
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 18, 2026
765e167
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 19, 2026
0dc75bc
add comment
ZackyLake Jan 20, 2026
@@ -131,6 +131,9 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose(
v0::Constant::create(ov::element::i64, ov::Shape{1}, {past_key.get_partial_shape()[2].get_length()}));
past_key = register_new_node<v8::Slice>(past_key, current_kv_len_const, past_kv_len_const, one, two);
past_value = register_new_node<v8::Slice>(past_value, current_kv_len_const, past_kv_len_const, one, two);
} else {
past_key = register_new_node<v8::Slice>(past_key, zero, past_seqlen, one, two);
past_value = register_new_node<v8::Slice>(past_value, zero, past_seqlen, one, two);
}
K = construct_kv_cache(past_key, K);
V = construct_kv_cache(past_value, V);
40 changes: 40 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp
@@ -33,6 +33,34 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
int64_t gather_axis,
const ov::element::Type output_type = ov::element::dynamic);

/// KVCache with seq_len trimming
KVCache(const Output<Node>& past,
const Output<Node>& new_token_data,
const Output<Node>& past_seq_len,
const std::shared_ptr<ov::op::util::Variable>& past_values,
int64_t concat_axis,
const ov::element::Type output_type = ov::element::dynamic);

/// KVCache with seq_len trimming and beam_idx
KVCache(const Output<Node>& past,
const Output<Node>& new_token_data,
const Output<Node>& beam_idx,
const Output<Node>& past_seq_len,
const std::shared_ptr<ov::op::util::Variable>& past_values,
int64_t concat_axis,
int64_t gather_axis,
const ov::element::Type output_type = ov::element::dynamic);

/// KVCache with update&trimming for tree-based speculative decoding
KVCache(const Output<Node>& past,
const Output<Node>& new_token_data,
const Output<Node>& past_seq_len,
const Output<Node>& dst_idx,
const Output<Node>& update_data,
const std::shared_ptr<ov::op::util::Variable>& past_values,
int64_t concat_axis,
const ov::element::Type output_type = ov::element::dynamic);

bool visit_attributes(ov::AttributeVisitor& visitor) override;

void validate_and_infer_types() override;
@@ -51,18 +79,30 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
void set_gather_axis(int64_t axis) { m_gather_axis = axis; }

bool get_indirect() const { return m_indirect; }
bool get_trim() const { return m_trim; }
bool get_update_kv() const { return m_update_kv; }

void set_trim(bool trim) { m_trim = trim; }
void set_update_kv(bool update_kv) { m_update_kv = update_kv; }

int64_t get_trim_length() const { return m_trim_length; }
void set_trim_length(int64_t trim_length) { m_trim_length = trim_length; }

protected:
KVCache(const OutputVector& inputs,
const std::shared_ptr<ov::op::util::Variable>& past_values,
bool indirect,
bool trim,
int64_t concat_axis,
int64_t gather_axis,
const ov::element::Type output_type = ov::element::dynamic);

int64_t m_concat_axis = 0;
int64_t m_gather_axis = 0;
bool m_indirect = false;
bool m_trim = false;
bool m_update_kv = false;
int64_t m_trim_length = 0;

ov::element::Type m_output_type;
};
@@ -21,6 +21,7 @@ class KVCacheCompressed : public ov::intel_gpu::op::KVCache {

KVCacheCompressed(const OutputVector& inputs,
const std::shared_ptr<ov::op::util::Variable>& past_values,
bool trim,
int64_t concat_axis,
int64_t gather_axis,
const QuantizationAttrs& quantization_attrs,
23 changes: 20 additions & 3 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/kv_cache.hpp
@@ -27,18 +27,27 @@ struct kv_cache : public primitive_base<kv_cache> {
const ov::op::util::VariableInfo& variable_info,
const int64_t concat_axis,
const int64_t gather_axis,
const bool indirect)
const bool indirect,
const bool trim,
const bool update_kv)
: primitive_base(id, inputs)
, variable_info(variable_info)
, concat_axis(concat_axis)
, gather_axis(gather_axis)
, indirect(indirect) {}
, indirect(indirect)
, trim(trim)
, update_kv(update_kv) {
if (update_kv) {
OPENVINO_ASSERT(trim, "update_kv must use trim");
}
}

ov::op::util::VariableInfo variable_info;
int64_t concat_axis = 0;
int64_t gather_axis = 0;
bool indirect = false;

bool trim = false;
bool update_kv = false;
bool compressed = false;
QuantizationAttributes quantization_attributes;

@@ -47,6 +56,8 @@ struct kv_cache : public primitive_base<kv_cache> {
seed = hash_combine(seed, concat_axis);
seed = hash_combine(seed, gather_axis);
seed = hash_combine(seed, indirect);
seed = hash_combine(seed, trim);
seed = hash_combine(seed, update_kv);
seed = hash_combine(seed, compressed);
seed = hash_range(seed, quantization_attributes.scales_zp_output_order.begin(), quantization_attributes.scales_zp_output_order.end());
seed = hash_range(seed, quantization_attributes.group_sizes.begin(), quantization_attributes.group_sizes.end());
@@ -69,6 +80,8 @@ struct kv_cache : public primitive_base<kv_cache> {
concat_axis == rhs_casted.concat_axis &&
gather_axis == rhs_casted.gather_axis &&
indirect == rhs_casted.indirect &&
trim == rhs_casted.trim &&
update_kv == rhs_casted.update_kv &&
compressed == rhs_casted.compressed &&
quantization_attributes.scales_zp_output_order == rhs_casted.quantization_attributes.scales_zp_output_order &&
quantization_attributes.output_storage_type == rhs_casted.quantization_attributes.output_storage_type &&
@@ -88,6 +101,8 @@ struct kv_cache : public primitive_base<kv_cache> {
ob << concat_axis;
ob << gather_axis;
ob << indirect;
ob << trim;
ob << update_kv;
ob << compressed;
ob << make_data(&quantization_attributes.quantization_type, sizeof(quantization_attributes.quantization_type));
ob << make_data(&quantization_attributes.quantization_dt, sizeof(quantization_attributes.quantization_dt));
@@ -110,6 +125,8 @@ struct kv_cache : public primitive_base<kv_cache> {
ib >> concat_axis;
ib >> gather_axis;
ib >> indirect;
ib >> trim;
ib >> update_kv;
ib >> compressed;
ib >> make_data(&quantization_attributes.quantization_type, sizeof(quantization_attributes.quantization_type));
ib >> make_data(&quantization_attributes.quantization_dt, sizeof(quantization_attributes.quantization_dt));
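The hash, equality, save, and load changes above all fold the new `trim` and `update_kv` flags into the primitive's identity, so kernels cached for one flag combination are never reused for another. A minimal sketch of that field-by-field seed chaining, assuming the common boost-style `hash_combine` formula (the actual cldnn helper may use different constants):

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>

// Boost-style hash_combine: fold one value into a running seed.
// Assumption: cldnn's hash_combine follows this well-known formula;
// only the chaining pattern matters for the sketch.
inline size_t hash_combine(size_t seed, size_t value) {
    return seed ^ (std::hash<size_t>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

// Illustrative mirror of kv_cache::hash(): every field that affects kernel
// selection must be folded in, in a fixed order, or two differently
// configured primitives could collide in the kernel cache.
struct kv_cache_key {
    int64_t concat_axis = 0;
    int64_t gather_axis = 0;
    bool indirect = false;
    bool trim = false;
    bool update_kv = false;

    size_t hash() const {
        size_t seed = 0;
        seed = hash_combine(seed, static_cast<size_t>(concat_axis));
        seed = hash_combine(seed, static_cast<size_t>(gather_axis));
        seed = hash_combine(seed, static_cast<size_t>(indirect));
        seed = hash_combine(seed, static_cast<size_t>(trim));
        seed = hash_combine(seed, static_cast<size_t>(update_kv));
        return seed;
    }
};
```

Note that `operator==`, `save`, and `load` must list the same fields in the same order; forgetting one of the three is a classic source of stale-cache bugs.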
@@ -444,15 +444,37 @@ bool crop_in_place_optimization::can_crop_be_optimized_simple_data_format(const
}

static bool can_read_value_be_optimize(const read_value_node& node) {
std::unordered_set<const cldnn::program_node*> unique_users(node.get_users().begin(), node.get_users().end());
if (unique_users.size() == 1)
std::unordered_set<const cldnn::program_node*> unique_users;
for (const auto user : node.get_users()) {
if (!user->is_type<shape_of>()) {
unique_users.insert(user);
}
}
if (unique_users.size() <= 1)
return true;

// The following pattern can be optimized; otherwise it could lead to corrupted data.
// readvalue's users eventually have to pass through kvcache before assign, which makes the kvcache node
// the dominator of the assign node, so it can safely be treated as if readvalue were directly connected to kvcache.
// readvalue --> any
// | |
// | v
// ------> kvcache
Contributor: Could you elaborate more on why/how it can be optimized?

Contributor: If read_value is not optimized, we get incorrect results in scatterelementupdate, so some change here is needed.

The original code simply checked whether readvalue has a single user; to be honest, I am not sure that proves anything: that user could actually be a no-op with multiple further users.

From the comment in its caller, it looks like the check is really trying to ensure that assign will not impact any subsequent user of readvalue, so the original logic was already not very convincing.

In any case, for our pattern, readvalue's user eventually has to pass through kvcache before assign, which makes the kvcache node the dominator of the assign node. It can therefore safely be treated as if readvalue were directly connected to kvcache, and can be optimized.

Contributor: Actually, my ask here was to add a comment on "why/how". As it is not blocking the code merge, could you follow up in a separate PR?
if (unique_users.size() == 2) {
const auto user0 = *unique_users.begin();
const auto user1 = *(++unique_users.begin());
const bool is_user0_kvcache = user0->is_type<kv_cache>();
const auto kvcache = is_user0_kvcache ? user0 : (user1->is_type<kv_cache>() ? user1 : nullptr);
if (kvcache) {
const auto other_user = is_user0_kvcache ? user1 : user0;
const bool only_used_by_kvcache = std::none_of(other_user->get_users().begin(), other_user->get_users().end(), [kvcache](const auto user) {
return user != kvcache && !user->template is_type<shape_of>();
});
if (only_used_by_kvcache) {
return true;
}
}
}

return false;
}
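The relaxed two-user check can be sketched on a toy node graph; `node`, `is_kv_cache`, `is_shape_of`, and `can_optimize` here are illustrative stand-ins for `cldnn::program_node` and its real type queries:

```cpp
#include <algorithm>
#include <vector>

// Toy graph node: flags replace cldnn's is_type<kv_cache>() / is_type<shape_of>().
struct node {
    bool is_kv_cache = false;
    bool is_shape_of = false;
    std::vector<const node*> users;
};

// Sketch of the relaxed read_value in-place check: a read_value with two
// (non-shape_of) users is still optimizable when the non-kv_cache user
// ultimately feeds only the kv_cache node, because kv_cache then dominates
// the assign and no user can observe stale variable data.
bool can_optimize(const node& read_value) {
    std::vector<const node*> unique_users;
    for (const node* user : read_value.users) {
        if (!user->is_shape_of &&
            std::find(unique_users.begin(), unique_users.end(), user) == unique_users.end())
            unique_users.push_back(user);
    }
    if (unique_users.size() <= 1)
        return true;
    if (unique_users.size() == 2) {
        const node* kvcache = unique_users[0]->is_kv_cache ? unique_users[0]
                            : unique_users[1]->is_kv_cache ? unique_users[1]
                            : nullptr;
        if (kvcache != nullptr) {
            const node* other = (unique_users[0] == kvcache) ? unique_users[1] : unique_users[0];
            // The other user may only reach the kv_cache node (shape_of users are ignored).
            return std::none_of(other->users.begin(), other->users.end(),
                                [kvcache](const node* u) { return u != kvcache && !u->is_shape_of; });
        }
    }
    return false;
}
```

A full dominator analysis would generalize this, but the two-user special case is all the fused GQA pattern produces.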