-
Notifications
You must be signed in to change notification settings - Fork 3.1k
[GPU] Add the capability for KV cache to update past KV #33114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
5a95cc6
84d8095
ab5be7c
a0ed479
1c72628
4d67227
ff94cfe
c85841f
817c983
b9d5f30
07c75d8
1586380
d338bcc
648a5fc
d7043fe
145e0f5
e01cedf
4c2d73a
9db4cc7
73a739e
7914ddd
8e74647
9bfb862
5326963
f321f9b
ea925ed
6cc91a1
561bc54
d836186
6344883
4d06b8e
644cc75
abcdf72
a1ecd57
f8f1a58
3d65d54
313a748
c429feb
6781639
6a3bbd2
3df8bac
765e167
0dc75bc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,4 +1,4 @@ | ||||||
| // Copyright (C) 2023-2024 Intel Corporation | ||||||
| // Copyright (C) 2023-2025 Intel Corporation | ||||||
| // SPDX-License-Identifier: Apache-2.0 | ||||||
| // | ||||||
|
|
||||||
|
|
@@ -16,6 +16,8 @@ | |||||
| #include "beam_table_update/beam_table_update_kernel_ref.hpp" | ||||||
| #include "dynamic_quantize/dynamic_quantize_kernel_selector.h" | ||||||
| #include "dynamic_quantize/dynamic_quantize_kernel_kv_cache.h" | ||||||
| #include "reorder_kv_cache/reorder_kv_cache_kernel_selector.hpp" | ||||||
| #include "reorder_kv_cache/reorder_kv_cache_kernel_ref.hpp" | ||||||
| #include "openvino/core/dimension.hpp" | ||||||
|
|
||||||
| #include <limits.h> | ||||||
|
|
@@ -71,34 +73,38 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> { | |||||
| return make_deep_copy<kv_cache_impl, kernel_params_t>(*this); | ||||||
| } | ||||||
|
|
||||||
| const size_t concat_stage = 0; | ||||||
| const size_t beam_table_stage = 1; | ||||||
| const size_t dq_stage = 2; | ||||||
| const size_t scale_concat_stage = 3; | ||||||
| const size_t zp_concat_stage = 4; | ||||||
| const size_t reorder_trim_stage = 0; | ||||||
| const size_t concat_stage = 1; | ||||||
| const size_t beam_table_stage = 2; | ||||||
| const size_t dq_stage = 3; | ||||||
| const size_t scale_concat_stage = 4; | ||||||
| const size_t zp_concat_stage = 5; | ||||||
|
|
||||||
| cldnn::memory::ptr beam_table_prev = nullptr; | ||||||
| cldnn::memory::ptr beam_table_new = nullptr; | ||||||
| size_t indirect_offset = 0; | ||||||
|
|
||||||
| void load(BinaryInputBuffer& ib) override { | ||||||
| parent::load(ib); | ||||||
| if (is_dynamic()) { | ||||||
| auto& kernel_selector = kernel_selector_t::Instance(); | ||||||
| auto kernel_impl = kernel_selector.GetImplementation(_kernels_data[concat_stage].kernelName); | ||||||
| kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[concat_stage]); | ||||||
| if (_kernels_data.size() >= 2) { | ||||||
| auto reorder_kernel_impl = kernel_selector.GetImplementation(_kernels_data[reorder_trim_stage].kernelName); | ||||||
| reorder_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[reorder_trim_stage]); | ||||||
| auto concat_kernel_impl = kernel_selector.GetImplementation(_kernels_data[concat_stage].kernelName); | ||||||
| concat_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[concat_stage]); | ||||||
| if (_kernels_data.size() >= 3) { | ||||||
| auto& bt_kernel_selector = bt_kernel_selector_t::Instance(); | ||||||
| auto bt_kernel_impl = bt_kernel_selector.GetImplementation(_kernels_data[beam_table_stage].kernelName); | ||||||
| bt_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[beam_table_stage]); | ||||||
| } | ||||||
|
|
||||||
| if (_kernels_data.size() >= 3) { | ||||||
| if (_kernels_data.size() >= 4) { | ||||||
| auto& dq_kernel_selector = dq_kernel_selector_t::Instance(); | ||||||
| auto dq_kernel_impl = dq_kernel_selector.GetImplementation(_kernels_data[dq_stage].kernelName); | ||||||
| dq_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[dq_stage]); | ||||||
| } | ||||||
|
|
||||||
| if (_kernels_data.size() >= 4) { | ||||||
| if (_kernels_data.size() >= 5) { | ||||||
| auto& scale_zp_concat_kernel_selector = kernel_selector_t::Instance(); | ||||||
| auto scale_zp_concat_kernel_impl = scale_zp_concat_kernel_selector.GetImplementation(_kernels_data[scale_concat_stage].kernelName); | ||||||
| scale_zp_concat_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[scale_concat_stage]); | ||||||
|
|
@@ -112,7 +118,10 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> { | |||||
| // output buffers order: [current, (beam_table), (current_scale), (current_zp)] | ||||||
| kernel_arguments_data args; | ||||||
| args.shape_info = instance.shape_info_memory_ptr(); | ||||||
| if (stage == concat_stage) { | ||||||
| if (stage == reorder_trim_stage) { | ||||||
| args.inputs = { instance.input_memory_ptr(0), instance.input_memory_ptr(3 + indirect_offset), instance.input_memory_ptr(4 + indirect_offset) }; | ||||||
| args.outputs = {instance.input_memory_ptr(0)}; | ||||||
| } else if (stage == concat_stage) { | ||||||
| args.inputs = { instance.input_memory_ptr(0), instance.input_memory_ptr(1) }; | ||||||
| args.outputs = { instance.output_memory_ptr(0) }; | ||||||
| } else if (stage == beam_table_stage) { | ||||||
|
|
@@ -186,10 +195,17 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> { | |||||
| const auto& desc = instance.get_typed_desc<kv_cache>(); | ||||||
| auto& variable = instance.get_network().get_variable(desc->variable_info.variable_id); | ||||||
| std::vector<event::ptr> res_events; | ||||||
| const auto& impl_param = *instance.get_impl_params(); | ||||||
|
|
||||||
| if (impl_param.input_layouts.size() >= 3) { | ||||||
| indirect_offset = desc->indirect ? 1 : 0; | ||||||
| if (instance.input_memory_ptr(0) && instance.input_memory_ptr(3 + indirect_offset)->size()) { | ||||||
| execute_stage(events, instance, res_events, reorder_trim_stage); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| execute_stage(events, instance, res_events, concat_stage); | ||||||
|
|
||||||
| const auto& impl_param = *instance.get_impl_params(); | ||||||
| const auto& kv_in_shape = impl_param.input_layouts[0].get_partial_shape(); | ||||||
| const auto& kv_out_shape = impl_param.output_layouts[0].get_partial_shape(); | ||||||
| if (desc->indirect && ((kv_out_shape[desc->gather_axis].get_length() > 1) || | ||||||
|
|
@@ -296,6 +312,38 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> { | |||||
| return layout{beam_table_shape, impl_param.output_layouts[1].data_type, format::get_default_format(beam_table_shape.size())}; | ||||||
| } | ||||||
|
|
||||||
| static kernel_selector::reorder_kv_cache_params get_reorder_trim_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) { | ||||||
| const auto& primitive = impl_param.typed_desc<kv_cache>(); | ||||||
|
||||||
| const auto& primitive = impl_param.typed_desc<kv_cache>(); | |
| auto primitive = impl_param.typed_desc<kv_cache>(); |
Kotomi-Du marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,7 +22,13 @@ struct typed_program_node<kv_cache> : public typed_program_node_base<kv_cache> { | |
|
|
||
| program_node& input() const { return get_dependency(0); } | ||
|
|
||
| std::vector<size_t> get_shape_infer_dependencies() const override { return {}; } | ||
| std::vector<size_t> get_shape_infer_dependencies() const override { | ||
| std::vector<size_t> vec; | ||
| for (size_t i = 1; i < get_dependencies().size(); i++) { | ||
|
||
| vec.push_back(i); | ||
| } | ||
| return vec; | ||
| } | ||
|
|
||
| std::vector<layout> get_shape_info_input_layouts() const override { | ||
| std::vector<layout> res; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
| #include "primitive_type_base.h" | ||
| #include <sstream> | ||
| #include <json_object.h> | ||
| #include "utils.hpp" | ||
|
|
||
| namespace cldnn { | ||
| GPU_DEFINE_PRIMITIVE_TYPE_ID(kv_cache) | ||
|
|
@@ -31,6 +32,17 @@ std::vector<layout> kv_cache_inst::calc_output_layouts(kv_cache_node const& /*no | |
|
|
||
| std::vector<ShapeType> input_shapes = {impl_param.get_input_layout(0).get<ShapeType>(), | ||
| impl_param.get_input_layout(1).get<ShapeType>()}; | ||
|
|
||
| std::unordered_map<size_t, ov::Tensor> const_data; | ||
| if (desc->trim) { | ||
| if(impl_param.memory_deps.count(2) > 0) | ||
|
||
| { | ||
| auto past_seq_len_mem = impl_param.memory_deps.at(2); | ||
| cldnn::mem_lock<uint8_t, mem_lock_type::read> past_seq_len_mem_lock(past_seq_len_mem, impl_param.get_stream()); | ||
| const_data.emplace(1, make_tensor(past_seq_len_mem->get_layout(), past_seq_len_mem_lock.data())); | ||
| } | ||
| } | ||
|
|
||
| if (desc->indirect) { | ||
| input_shapes.push_back(impl_param.get_input_layout(2).get<ShapeType>()); | ||
| } | ||
|
|
@@ -50,14 +62,44 @@ std::vector<layout> kv_cache_inst::calc_output_layouts(kv_cache_node const& /*no | |
| op.set_concat_axis(desc->concat_axis); | ||
| op.set_gather_axis(desc->gather_axis); | ||
| op.set_quantization_attrs(desc->quantization_attributes); | ||
| if (desc->trim) { | ||
| if (auto past_dim_updated = ov::op::get_input_const_data_as<ov::PartialShape, int64_t>(&op, 1, ov::make_tensor_accessor(const_data))) { | ||
| auto past_dim_stored = input_shapes[0][desc->concat_axis]; | ||
| if (past_dim_stored.is_static()) { | ||
| auto trim_length = past_dim_stored.get_length() - (*past_dim_updated)[0]; | ||
| if (trim_length > 0) { | ||
| op.set_trim_length(static_cast<uint64_t>(trim_length)); | ||
| impl_param.kv_cache_trim_length = trim_length; | ||
| } else { | ||
| op.set_trim_length(static_cast<uint64_t>(0)); | ||
| impl_param.kv_cache_trim_length = 0; | ||
|
|
||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| output_shapes = shape_infer(&op, input_shapes); | ||
| } else { | ||
| ov::intel_gpu::op::KVCache op; | ||
| op.set_output_size(desc->num_outputs); | ||
| op.set_concat_axis(desc->concat_axis); | ||
| op.set_gather_axis(desc->gather_axis); | ||
|
|
||
| if (desc->trim) { | ||
| if (auto past_dim_updated = ov::op::get_input_const_data_as<ov::PartialShape, int64_t>(&op, 1, ov::make_tensor_accessor(const_data))) { | ||
| auto past_dim_stored = input_shapes[0][desc->concat_axis]; | ||
| if (past_dim_stored.is_static()) { | ||
| auto trim_length = past_dim_stored.get_length() - (*past_dim_updated)[0]; | ||
| if (trim_length > 0) { | ||
| op.set_trim_length(static_cast<uint64_t>(trim_length)); | ||
| impl_param.kv_cache_trim_length = trim_length; | ||
| } else { | ||
| op.set_trim_length(static_cast<uint64_t>(0)); | ||
| impl_param.kv_cache_trim_length = 0; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| output_shapes = shape_infer(&op, input_shapes); | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not a right place to add such kv-cache specific field.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you suggest any other place to put this? Here is the investigation on our side.
The table below shows all the available KVCache-related files, but none of them is a suitable place for this parameter, which needs to be updated on every iteration at runtime.
Specifically, for
`kv_cache_inst.h`, `kv_cache_inst::trim_length` couldn't be updated in the static function `calc_output_layout()`. It also doesn't make sense to make `kv_cache_inst::trim_length` static to work around this, because it would lead to a data race across multiple KV-cache instances or threads. Furthermore,
`kernel_impl_params.h` also includes other op-specific variables with a TODO comment (prior-box). So, it seems acceptable in our case as well.
`ov::intel_gpu::op::KVCache` — src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp; `cldnn::kv_cache` — src/plugins/intel_gpu/include/intel_gpu/primitives/kv_cache.hpp; `cldnn::typed_primitive_inst<kv_cache>` — src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about introducing a separate method in kv_cache_inst.h? You can make a non-static method and use it to store the information in primitive_inst. Then this API can simply be called after shape inference. You should not place this field in kernel_impl_params.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated implementation looks good for trim_length.