Merged

43 commits
5a95cc6
fuse GQA slice node into kvCache for in-place crop
Kotomi-Du Nov 25, 2025
84d8095
fix conformance issue
Kotomi-Du Dec 2, 2025
ab5be7c
Use RemoteTensor to reorder KV cache
mdvoretc-intel Nov 17, 2025
a0ed479
Add a kernel to reorder KV cache
mdvoretc-intel Nov 19, 2025
1c72628
Add KVCache index fusion for reorder
mdvoretc-intel Nov 20, 2025
4d67227
Fix basic issues
mdvoretc-intel Nov 21, 2025
ff94cfe
Prevent KV reorder execution for cases where it's not required
mdvoretc-intel Nov 21, 2025
c85841f
Fix scalar arguments bug, remove debug prints
mdvoretc-intel Nov 27, 2025
817c983
Remove unused gather_by_axis code
mdvoretc-intel Dec 2, 2025
b9d5f30
Fix input offsets
mdvoretc-intel Dec 3, 2025
07c75d8
Add unit test case
mdvoretc-intel Dec 3, 2025
1586380
Add feature bounds check
mdvoretc-intel Dec 3, 2025
d338bcc
clean up code
Kotomi-Du Dec 4, 2025
648a5fc
clean up execution stage
Kotomi-Du Dec 5, 2025
d7043fe
use scatterElementUpdate kernel instead of self customized kernel
Kotomi-Du Dec 5, 2025
145e0f5
delete customized kernel path
Kotomi-Du Dec 5, 2025
e01cedf
clean up code
Kotomi-Du Dec 6, 2025
4c2d73a
fix code style
Kotomi-Du Dec 10, 2025
9db4cc7
adjust index for compressed KV stage when update_kv stage is existed
Kotomi-Du Dec 10, 2025
73a739e
refactor tests, merge duplicated code
ZackyLake Dec 18, 2025
7914ddd
refactor kvcache stage.
ZackyLake Dec 18, 2025
8e74647
remove update_kv logic on compress kv.
ZackyLake Dec 30, 2025
9bfb862
remove indirect support on kv_update due to lack of test.
ZackyLake Dec 30, 2025
5326963
add debug priont for skipped kernel
ZackyLake Jan 7, 2026
f321f9b
Merge branch 'master' into update_kvcache_node
ZackyLake Jan 7, 2026
ea925ed
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 7, 2026
6cc91a1
fix kv fusion pattern.
ZackyLake Jan 10, 2026
561bc54
Merge branch 'master' into update_kvcache_node
ZackyLake Jan 10, 2026
d836186
fix test
ZackyLake Jan 12, 2026
6344883
move trim_length to kv_cache_inst.h
Kotomi-Du Jan 7, 2026
4d06b8e
fix kv fusion for test(stridedslice)
ZackyLake Jan 15, 2026
644cc75
include trim-only support
ZackyLake Jan 15, 2026
abcdf72
fix concat_axis signness
ZackyLake Jan 15, 2026
a1ecd57
fix fusion logic
ZackyLake Jan 15, 2026
f8f1a58
fix signedness
ZackyLake Jan 15, 2026
3d65d54
allow trim on indirect kvcache.
ZackyLake Jan 16, 2026
313a748
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 16, 2026
c429feb
Make CompressedKV compatible with trim.
ZackyLake Jan 17, 2026
6781639
fix
ZackyLake Jan 17, 2026
6a3bbd2
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 17, 2026
3df8bac
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 18, 2026
765e167
Merge branch 'master' into update_kvcache_node
Kotomi-Du Jan 19, 2026
0dc75bc
add comment
ZackyLake Jan 20, 2026
@@ -79,6 +79,10 @@ struct kernel_impl_params final {
std::vector<size_t> output_size;
std::vector<size_t> img_size;

// KV cache trim length - set at runtime during shape inference
// Marked as mutable to allow modification even when kernel_impl_params is passed as const reference
mutable int64_t kv_cache_trim_length = 0;
Contributor:
This is not the right place to add such a KV-cache-specific field.

@Kotomi-Du (Author), Jan 6, 2026:
Could you suggest another place to put this? Here is the investigation on our side.

The table below lists all the available KVCache-related files, but none of them is suitable for this parameter, which has to be updated at runtime on every iteration.

Specifically, for kv_cache_inst.h, kv_cache_inst::trim_length cannot be updated from the static function calc_output_layout(). Making kv_cache_inst::trim_length static to work around that also doesn't make sense, because it would lead to data races across multiple KV-cache instances or threads.

Furthermore, kernel_impl_params.h already includes other op-specific variables with a TODO comment (prior-box), so this seems acceptable in our case as well.

| Description | File Path | Notes |
| --- | --- | --- |
| ov::intel_gpu::op::KVCache | src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp | it cannot serve runtime updates |
| cldnn::kv_cache | src/plugins/intel_gpu/include/intel_gpu/primitives/kv_cache.hpp | it stores static attributes and can be serialized into the kernel's binary cache |
| cldnn::typed_primitive_inst<kv_cache> | src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h | node instance |

Contributor:
What about introducing a separate method in kv_cache_inst.h? You can make a non-static method and use it to store the information in primitive_inst; that API can then be called right after shape inference. You should not place this field in kernel_impl_params.

Contributor:
Updated implementation looks good for trim_length.

std::map<size_t, size_t> in_port_to_shape_info_offset = {};
std::map<size_t, size_t> out_port_to_shape_info_offset = {};

26 changes: 26 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp
@@ -33,6 +33,26 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
int64_t gather_axis,
const ov::element::Type output_type = ov::element::dynamic);

KVCache(const Output<Node>& past,
const Output<Node>& new_token_data,
const Output<Node>& split_seq,
const Output<Node>& src_idx,
const Output<Node>& dst_idx,
const std::shared_ptr<ov::op::util::Variable>& past_values,
int64_t concat_axis,
const ov::element::Type output_type = ov::element::dynamic);

KVCache(const Output<Node>& past,
const Output<Node>& new_token_data,
const Output<Node>& beam_idx,
const Output<Node>& split_seq,
const Output<Node>& src_idx,
const Output<Node>& dst_idx,
const std::shared_ptr<ov::op::util::Variable>& past_values,
int64_t concat_axis,
int64_t gather_axis,
const ov::element::Type output_type = ov::element::dynamic);

bool visit_attributes(ov::AttributeVisitor& visitor) override;

void validate_and_infer_types() override;
@@ -51,6 +71,10 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
void set_gather_axis(int64_t axis) { m_gather_axis = axis; }

bool get_indirect() const { return m_indirect; }
bool get_trim() const { return m_trim; }

uint64_t get_trim_length() const { return m_trim_length; }
void set_trim_length(uint64_t trim_length) { m_trim_length = trim_length; }

protected:
KVCache(const OutputVector& inputs,
Expand All @@ -63,6 +87,8 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
int64_t m_concat_axis = 0;
int64_t m_gather_axis = 0;
bool m_indirect = false;
bool m_trim = false;
uint64_t m_trim_length = 0;

ov::element::Type m_output_type;
};
@@ -27,18 +27,20 @@ struct kv_cache : public primitive_base<kv_cache> {
const ov::op::util::VariableInfo& variable_info,
const int64_t concat_axis,
const int64_t gather_axis,
const bool indirect)
const bool indirect,
const bool trim)
: primitive_base(id, inputs)
, variable_info(variable_info)
, concat_axis(concat_axis)
, gather_axis(gather_axis)
, indirect(indirect) {}
, indirect(indirect)
, trim(trim) {}

ov::op::util::VariableInfo variable_info;
int64_t concat_axis = 0;
int64_t gather_axis = 0;
bool indirect = false;

bool trim = false;
bool compressed = false;
QuantizationAttributes quantization_attributes;

@@ -47,6 +49,7 @@ struct kv_cache : public primitive_base<kv_cache> {
seed = hash_combine(seed, concat_axis);
seed = hash_combine(seed, gather_axis);
seed = hash_combine(seed, indirect);
seed = hash_combine(seed, trim);
seed = hash_combine(seed, compressed);
seed = hash_range(seed, quantization_attributes.scales_zp_output_order.begin(), quantization_attributes.scales_zp_output_order.end());
seed = hash_range(seed, quantization_attributes.group_sizes.begin(), quantization_attributes.group_sizes.end());
@@ -69,6 +72,7 @@ struct kv_cache : public primitive_base<kv_cache> {
concat_axis == rhs_casted.concat_axis &&
gather_axis == rhs_casted.gather_axis &&
indirect == rhs_casted.indirect &&
trim == rhs_casted.trim &&
compressed == rhs_casted.compressed &&
quantization_attributes.scales_zp_output_order == rhs_casted.quantization_attributes.scales_zp_output_order &&
quantization_attributes.output_storage_type == rhs_casted.quantization_attributes.output_storage_type &&
@@ -88,6 +92,7 @@ struct kv_cache : public primitive_base<kv_cache> {
ob << concat_axis;
ob << gather_axis;
ob << indirect;
ob << trim;
ob << compressed;
ob << make_data(&quantization_attributes.quantization_type, sizeof(quantization_attributes.quantization_type));
ob << make_data(&quantization_attributes.quantization_dt, sizeof(quantization_attributes.quantization_dt));
@@ -110,6 +115,7 @@ struct kv_cache : public primitive_base<kv_cache> {
ib >> concat_axis;
ib >> gather_axis;
ib >> indirect;
ib >> trim;
ib >> compressed;
ib >> make_data(&quantization_attributes.quantization_type, sizeof(quantization_attributes.quantization_type));
ib >> make_data(&quantization_attributes.quantization_dt, sizeof(quantization_attributes.quantization_dt));
92 changes: 78 additions & 14 deletions src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp
@@ -1,4 +1,4 @@
// Copyright (C) 2023-2024 Intel Corporation
// Copyright (C) 2023-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

@@ -16,6 +16,8 @@
#include "beam_table_update/beam_table_update_kernel_ref.hpp"
#include "dynamic_quantize/dynamic_quantize_kernel_selector.h"
#include "dynamic_quantize/dynamic_quantize_kernel_kv_cache.h"
#include "reorder_kv_cache/reorder_kv_cache_kernel_selector.hpp"
#include "reorder_kv_cache/reorder_kv_cache_kernel_ref.hpp"
#include "openvino/core/dimension.hpp"

#include <limits.h>
@@ -71,34 +73,38 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
return make_deep_copy<kv_cache_impl, kernel_params_t>(*this);
}

const size_t concat_stage = 0;
const size_t beam_table_stage = 1;
const size_t dq_stage = 2;
const size_t scale_concat_stage = 3;
const size_t zp_concat_stage = 4;
const size_t reorder_trim_stage = 0;
const size_t concat_stage = 1;
const size_t beam_table_stage = 2;
const size_t dq_stage = 3;
const size_t scale_concat_stage = 4;
const size_t zp_concat_stage = 5;

cldnn::memory::ptr beam_table_prev = nullptr;
cldnn::memory::ptr beam_table_new = nullptr;
size_t indirect_offset = 0;

void load(BinaryInputBuffer& ib) override {
parent::load(ib);
if (is_dynamic()) {
auto& kernel_selector = kernel_selector_t::Instance();
auto kernel_impl = kernel_selector.GetImplementation(_kernels_data[concat_stage].kernelName);
kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[concat_stage]);
if (_kernels_data.size() >= 2) {
auto reorder_kernel_impl = kernel_selector.GetImplementation(_kernels_data[reorder_trim_stage].kernelName);
reorder_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[reorder_trim_stage]);
auto concat_kernel_impl = kernel_selector.GetImplementation(_kernels_data[concat_stage].kernelName);
concat_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[concat_stage]);
if (_kernels_data.size() >= 3) {
auto& bt_kernel_selector = bt_kernel_selector_t::Instance();
auto bt_kernel_impl = bt_kernel_selector.GetImplementation(_kernels_data[beam_table_stage].kernelName);
bt_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[beam_table_stage]);
}

if (_kernels_data.size() >= 3) {
if (_kernels_data.size() >= 4) {
auto& dq_kernel_selector = dq_kernel_selector_t::Instance();
auto dq_kernel_impl = dq_kernel_selector.GetImplementation(_kernels_data[dq_stage].kernelName);
dq_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[dq_stage]);
}

if (_kernels_data.size() >= 4) {
if (_kernels_data.size() >= 5) {
auto& scale_zp_concat_kernel_selector = kernel_selector_t::Instance();
auto scale_zp_concat_kernel_impl = scale_zp_concat_kernel_selector.GetImplementation(_kernels_data[scale_concat_stage].kernelName);
scale_zp_concat_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[scale_concat_stage]);
@@ -112,7 +118,10 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
// output buffers order: [current, (beam_table), (current_scale), (current_zp)]
kernel_arguments_data args;
args.shape_info = instance.shape_info_memory_ptr();
if (stage == concat_stage) {
if (stage == reorder_trim_stage) {
args.inputs = { instance.input_memory_ptr(0), instance.input_memory_ptr(3 + indirect_offset), instance.input_memory_ptr(4 + indirect_offset) };
args.outputs = {instance.input_memory_ptr(0)};
} else if (stage == concat_stage) {
args.inputs = { instance.input_memory_ptr(0), instance.input_memory_ptr(1) };
args.outputs = { instance.output_memory_ptr(0) };
} else if (stage == beam_table_stage) {
@@ -186,10 +195,17 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
const auto& desc = instance.get_typed_desc<kv_cache>();
auto& variable = instance.get_network().get_variable(desc->variable_info.variable_id);
std::vector<event::ptr> res_events;
const auto& impl_param = *instance.get_impl_params();

if (impl_param.input_layouts.size() >= 3) {
indirect_offset = desc->indirect ? 1 : 0;
if (instance.input_memory_ptr(0) && instance.input_memory_ptr(3 + indirect_offset)->size()) {
execute_stage(events, instance, res_events, reorder_trim_stage);
}
}

execute_stage(events, instance, res_events, concat_stage);

const auto& impl_param = *instance.get_impl_params();
const auto& kv_in_shape = impl_param.input_layouts[0].get_partial_shape();
const auto& kv_out_shape = impl_param.output_layouts[0].get_partial_shape();
if (desc->indirect && ((kv_out_shape[desc->gather_axis].get_length() > 1) ||
@@ -296,6 +312,38 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
return layout{beam_table_shape, impl_param.output_layouts[1].data_type, format::get_default_format(beam_table_shape.size())};
}

static kernel_selector::reorder_kv_cache_params get_reorder_trim_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
const auto& primitive = impl_param.typed_desc<kv_cache>();
Contributor:
typed_desc returns a std::shared_ptr, so binding it to a const reference like this might create a dangling reference.

Suggested change:
-    const auto& primitive = impl_param.typed_desc<kv_cache>();
+    auto primitive = impl_param.typed_desc<kv_cache>();
auto params = get_default_params<kernel_selector::reorder_kv_cache_params>(impl_param, is_shape_agnostic);

auto inputs_count = 3;

params.inputs.resize(inputs_count);
params.inputs[0] = convert_data_tensor(impl_param.input_layouts[0], tensor());
params.inputs[1] = convert_data_tensor(impl_param.input_layouts[3 + (impl_param.typed_desc<kv_cache>()->indirect ? 1 : 0)], tensor());
params.inputs[2] = convert_data_tensor(impl_param.input_layouts[4 + (impl_param.typed_desc<kv_cache>()->indirect ? 1 : 0)], tensor());
params.outputs[0] = convert_data_tensor(impl_param.output_layouts[0], tensor());
params.seq_len = params.inputs[0].Y().pitch ? params.inputs[0].Feature().pitch / params.inputs[0].Y().pitch : 0;
params.idx_len = params.inputs[2].Y().v;

const auto& desc = impl_param.typed_desc<kv_cache>();

const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; // [kv_past, kv_new_token, [beam_idx, [scale_past], [zp_past], beam_table_past]]
const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset; // [kv_present, beam_table_present, compression_scale_present]
std::map<size_t, size_t> in_tensor_to_offset_map = {
{0, in_offsets_map.at(0)}, // kv_past
{1, in_offsets_map.at(3 + (impl_param.typed_desc<kv_cache>()->indirect ? 1 : 0))}, // src_idx
{2, in_offsets_map.at(4 + (impl_param.typed_desc<kv_cache>()->indirect ? 1 : 0))}, // dst_idx
};
std::map<size_t, size_t> out_tensor_to_offset_map = {
{0, in_offsets_map.at(0)}, // kv_present
};

params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);

return params;
}

static kernel_params_t get_concat_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
const auto& primitive = impl_param.typed_desc<kv_cache>();
auto params = get_default_params<kernel_selector::concatenation_params>(impl_param, is_shape_agnostic);
@@ -304,7 +352,14 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
const auto inputs_count = 2;
params.inputs.resize(inputs_count);
for (size_t i = 0; i < inputs_count; ++i) {
params.inputs[i] = convert_data_tensor(impl_param.input_layouts[i]);
auto target_layout = impl_param.input_layouts[i];
// Trim the cache
/*if (i == 0 && primitive->trim) {
auto shape = target_layout.get_partial_shape();
shape[axis] = shape[axis] - primitive->trim;
target_layout.set_partial_shape(shape);
}*/
params.inputs[i] = convert_data_tensor(target_layout);
}

params.axis = convert_axis(axis, impl_param.get_output_layout().get_rank());
@@ -454,6 +509,11 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {

static std::unique_ptr<primitive_impl> create(const typed_program_node<kv_cache>& arg, const kernel_impl_params& impl_param) {
std::vector<kernel_selector::kernel_data> kernels_data;
if (impl_param.typed_desc<kv_cache>()->input.size() >= 3) {
auto reorder_kernel_params = get_reorder_trim_kernel_params(impl_param, impl_param.is_dynamic());
auto& reorder_kernel_selector = kernel_selector::reorder_kv_cache_kernel_selector::Instance();
kernels_data.push_back(reorder_kernel_selector.get_best_kernel(reorder_kernel_params));
}
auto concat_kernel_params = get_concat_kernel_params(impl_param, impl_param.is_dynamic());
auto& concat_kernel_selector = kernel_selector_t::Instance();
kernels_data.push_back(concat_kernel_selector.get_best_kernel(concat_kernel_params));
@@ -486,6 +546,10 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
}

void update_dispatch_data(const kernel_impl_params& impl_param) override {
auto reorder_kernel_params = get_reorder_trim_kernel_params(impl_param, true);
(_kernels_data[reorder_trim_stage].update_dispatch_data_func)(reorder_kernel_params, _kernels_data[reorder_trim_stage]);
_kernels_data[reorder_trim_stage].kernels[0].skip_execution = (reorder_kernel_params.seq_len == 0) || (reorder_kernel_params.idx_len == 0);

// If model loaded from cache, params are not initialized, so we create a new object and reuse it in the future
if (_kernels_data[concat_stage].params == nullptr) {
_kernels_data[concat_stage].params = std::make_shared<kernel_params_t>(get_concat_kernel_params(impl_param, true));
8 changes: 7 additions & 1 deletion src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h
@@ -22,7 +22,13 @@ struct typed_program_node<kv_cache> : public typed_program_node_base<kv_cache> {

program_node& input() const { return get_dependency(0); }

std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
std::vector<size_t> get_shape_infer_dependencies() const override {
std::vector<size_t> vec;
for (size_t i = 1; i < get_dependencies().size(); i++) {
Contributor:
Is it supposed to start from 1?

Contributor:
Why did all dependencies become shape_infer_dependencies? I think only the new index node should become a shape-infer dependency.

vec.push_back(i);
}
return vec;
}

std::vector<layout> get_shape_info_input_layouts() const override {
std::vector<layout> res;
44 changes: 43 additions & 1 deletion src/plugins/intel_gpu/src/graph/kv_cache.cpp
@@ -10,6 +10,7 @@
#include "primitive_type_base.h"
#include <sstream>
#include <json_object.h>
#include "utils.hpp"

namespace cldnn {
GPU_DEFINE_PRIMITIVE_TYPE_ID(kv_cache)
@@ -31,6 +32,17 @@ std::vector<layout> kv_cache_inst::calc_output_layouts(kv_cache_node const& /*no

std::vector<ShapeType> input_shapes = {impl_param.get_input_layout(0).get<ShapeType>(),
impl_param.get_input_layout(1).get<ShapeType>()};

std::unordered_map<size_t, ov::Tensor> const_data;
if (desc->trim) {
if (impl_param.memory_deps.count(2) > 0)
Contributor:
Is there a case where desc->update_kv is true and deps.count(2) <= 0? If not, what about replacing this check with OPENVINO_ASSERT?

{
auto past_seq_len_mem = impl_param.memory_deps.at(2);
cldnn::mem_lock<uint8_t, mem_lock_type::read> past_seq_len_mem_lock(past_seq_len_mem, impl_param.get_stream());
const_data.emplace(1, make_tensor(past_seq_len_mem->get_layout(), past_seq_len_mem_lock.data()));
}
}

if (desc->indirect) {
input_shapes.push_back(impl_param.get_input_layout(2).get<ShapeType>());
}
@@ -50,14 +62,44 @@ std::vector<layout> kv_cache_inst::calc_output_layouts(kv_cache_node const& /*no
op.set_concat_axis(desc->concat_axis);
op.set_gather_axis(desc->gather_axis);
op.set_quantization_attrs(desc->quantization_attributes);
if (desc->trim) {
if (auto past_dim_updated = ov::op::get_input_const_data_as<ov::PartialShape, int64_t>(&op, 1, ov::make_tensor_accessor(const_data))) {
auto past_dim_stored = input_shapes[0][desc->concat_axis];
if (past_dim_stored.is_static()) {
auto trim_length = past_dim_stored.get_length() - (*past_dim_updated)[0];
if (trim_length > 0) {
op.set_trim_length(static_cast<uint64_t>(trim_length));
impl_param.kv_cache_trim_length = trim_length;
} else {
op.set_trim_length(static_cast<uint64_t>(0));
impl_param.kv_cache_trim_length = 0;

}
}
}
}

output_shapes = shape_infer(&op, input_shapes);
} else {
ov::intel_gpu::op::KVCache op;
op.set_output_size(desc->num_outputs);
op.set_concat_axis(desc->concat_axis);
op.set_gather_axis(desc->gather_axis);

if (desc->trim) {
if (auto past_dim_updated = ov::op::get_input_const_data_as<ov::PartialShape, int64_t>(&op, 1, ov::make_tensor_accessor(const_data))) {
auto past_dim_stored = input_shapes[0][desc->concat_axis];
if (past_dim_stored.is_static()) {
auto trim_length = past_dim_stored.get_length() - (*past_dim_updated)[0];
if (trim_length > 0) {
op.set_trim_length(static_cast<uint64_t>(trim_length));
impl_param.kv_cache_trim_length = trim_length;
} else {
op.set_trim_length(static_cast<uint64_t>(0));
impl_param.kv_cache_trim_length = 0;
}
}
}
}
output_shapes = shape_infer(&op, input_shapes);
}
