diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp index a47d4a6eaecb4d..14c903a5d20c7b 100644 --- a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp +++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp @@ -131,6 +131,9 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( v0::Constant::create(ov::element::i64, ov::Shape{1}, {past_key.get_partial_shape()[2].get_length()})); past_key = register_new_node(past_key, current_kv_len_const, past_kv_len_const, one, two); past_value = register_new_node(past_value, current_kv_len_const, past_kv_len_const, one, two); + } else { + past_key = register_new_node(past_key, zero, past_seqlen, one, two); + past_value = register_new_node(past_value, zero, past_seqlen, one, two); } K = construct_kv_cache(past_key, K); V = construct_kv_cache(past_value, V); diff --git a/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp index 412cf3e538e0f4..87ca2d8b203f27 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp @@ -33,6 +33,34 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension { int64_t gather_axis, const ov::element::Type output_type = ov::element::dynamic); + /// KVCache with seq_len trimming + KVCache(const Output& past, + const Output& new_token_data, + const Output& past_seq_len, + const std::shared_ptr& past_values, + int64_t concat_axis, + const ov::element::Type output_type = ov::element::dynamic); + + /// KVCache with seq_len trimming and beam_idx + KVCache(const Output& past, + const Output& new_token_data, + const Output& beam_idx, + const Output& past_seq_len, + const 
std::shared_ptr& past_values, + int64_t concat_axis, + int64_t gather_axis, + const ov::element::Type output_type = ov::element::dynamic); + + /// KVCache with update&trimming for tree-based speculative decoding + KVCache(const Output& past, + const Output& new_token_data, + const Output& past_seq_len, + const Output& dst_idx, + const Output& update_data, + const std::shared_ptr& past_values, + int64_t concat_axis, + const ov::element::Type output_type = ov::element::dynamic); + bool visit_attributes(ov::AttributeVisitor& visitor) override; void validate_and_infer_types() override; @@ -51,11 +79,20 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension { void set_gather_axis(int64_t axis) { m_gather_axis = axis; } bool get_indirect() const { return m_indirect; } + bool get_trim() const { return m_trim; } + bool get_update_kv() const { return m_update_kv; } + + void set_trim(bool trim) { m_trim = trim; } + void set_update_kv(bool update_kv) { m_update_kv = update_kv; } + + int64_t get_trim_length() const { return m_trim_length; } + void set_trim_length(int64_t trim_length) { m_trim_length = trim_length; } protected: KVCache(const OutputVector& inputs, const std::shared_ptr& past_values, bool indirect, + bool trim, int64_t concat_axis, int64_t gather_axis, const ov::element::Type output_type = ov::element::dynamic); @@ -63,6 +100,9 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension { int64_t m_concat_axis = 0; int64_t m_gather_axis = 0; bool m_indirect = false; + bool m_trim = false; + bool m_update_kv = false; + int64_t m_trim_length = 0; ov::element::Type m_output_type; }; diff --git a/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache_compressed.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache_compressed.hpp index 65c397f22b2dfc..3b9191c469db24 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache_compressed.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache_compressed.hpp @@ -21,6 
+21,7 @@ class KVCacheCompressed : public ov::intel_gpu::op::KVCache { KVCacheCompressed(const OutputVector& inputs, const std::shared_ptr& past_values, + bool trim, int64_t concat_axis, int64_t gather_axis, const QuantizationAttrs& quantization_attrs, diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/kv_cache.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/kv_cache.hpp index 9bd5084f0b8a61..2058d88c471f35 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/kv_cache.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/kv_cache.hpp @@ -27,18 +27,27 @@ struct kv_cache : public primitive_base { const ov::op::util::VariableInfo& variable_info, const int64_t concat_axis, const int64_t gather_axis, - const bool indirect) + const bool indirect, + const bool trim, + const bool update_kv) : primitive_base(id, inputs) , variable_info(variable_info) , concat_axis(concat_axis) , gather_axis(gather_axis) - , indirect(indirect) {} + , indirect(indirect) + , trim(trim) + , update_kv(update_kv) { + if (update_kv){ + OPENVINO_ASSERT(trim, "update_kv must use trim"); + } + } ov::op::util::VariableInfo variable_info; int64_t concat_axis = 0; int64_t gather_axis = 0; bool indirect = false; - + bool trim = false; + bool update_kv = false; bool compressed = false; QuantizationAttributes quantization_attributes; @@ -47,6 +56,8 @@ struct kv_cache : public primitive_base { seed = hash_combine(seed, concat_axis); seed = hash_combine(seed, gather_axis); seed = hash_combine(seed, indirect); + seed = hash_combine(seed, trim); + seed = hash_combine(seed, update_kv); seed = hash_combine(seed, compressed); seed = hash_range(seed, quantization_attributes.scales_zp_output_order.begin(), quantization_attributes.scales_zp_output_order.end()); seed = hash_range(seed, quantization_attributes.group_sizes.begin(), quantization_attributes.group_sizes.end()); @@ -69,6 +80,8 @@ struct kv_cache : public primitive_base { concat_axis == rhs_casted.concat_axis && 
gather_axis == rhs_casted.gather_axis && indirect == rhs_casted.indirect && + trim == rhs_casted.trim && + update_kv == rhs_casted.update_kv && compressed == rhs_casted.compressed && quantization_attributes.scales_zp_output_order == rhs_casted.quantization_attributes.scales_zp_output_order && quantization_attributes.output_storage_type == rhs_casted.quantization_attributes.output_storage_type && @@ -88,6 +101,8 @@ struct kv_cache : public primitive_base { ob << concat_axis; ob << gather_axis; ob << indirect; + ob << trim; + ob << update_kv; ob << compressed; ob << make_data(&quantization_attributes.quantization_type, sizeof(quantization_attributes.quantization_type)); ob << make_data(&quantization_attributes.quantization_dt, sizeof(quantization_attributes.quantization_dt)); @@ -110,6 +125,8 @@ struct kv_cache : public primitive_base { ib >> concat_axis; ib >> gather_axis; ib >> indirect; + ib >> trim; + ib >> update_kv; ib >> compressed; ib >> make_data(&quantization_attributes.quantization_type, sizeof(quantization_attributes.quantization_type)); ib >> make_data(&quantization_attributes.quantization_dt, sizeof(quantization_attributes.quantization_dt)); diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp index a36d67bb36cbd7..c2a0af74db3161 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp @@ -444,15 +444,37 @@ bool crop_in_place_optimization::can_crop_be_optimized_simple_data_format(const } static bool can_read_value_be_optimize(const read_value_node& node) { - std::unordered_set unique_users(node.get_users().begin(), node.get_users().end()); - if (unique_users.size() == 1) + std::unordered_set unique_users; + for (const auto user : node.get_users()) { + if (!user->is_type()) { + unique_users.insert(user); + } + } + if (unique_users.size() 
<= 1) return true; - const auto non_shape_of_users_count = std::count_if(unique_users.begin(), unique_users.end(), [](const program_node* user) { - return !user->is_type(); - }); - if (non_shape_of_users_count <= 1) - return true; + // following pattern should be optimized, otherwise it could lead to corrupted data. + // readvalue's users eventually need to pass kvcache before assign, which makes kvcache node the dominator of assign node, + // it could be safely treated as if readvalue is directly connected to kvcache. + // readvalue --> any + // | | + // | v + // ------> kvcache + if (unique_users.size() == 2) { + const auto user0 = *unique_users.begin(); + const auto user1 = *(++unique_users.begin()); + const bool is_user0_kvcache = user0->is_type(); + const auto kvcache = is_user0_kvcache ? user0 : (user1->is_type() ? user1 : nullptr); + if (kvcache) { + const auto other_user = is_user0_kvcache ? user1 : user0; + const bool only_used_by_kvcache = std::none_of(other_user->get_users().begin(), other_user->get_users().end(), [kvcache](const auto user) { + return user != kvcache && !user->template is_type(); + }); + if (only_used_by_kvcache) { + return true; + } + } + } return false; } diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp index e7b5e87d3e065b..249b1ec403a4d6 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp @@ -16,6 +16,8 @@ #include "beam_table_update/beam_table_update_kernel_ref.hpp" #include "dynamic_quantize/dynamic_quantize_kernel_selector.h" #include "dynamic_quantize/dynamic_quantize_kernel_kv_cache.h" +#include "scatter_update/scatter_elements_update_kernel_selector.h" +#include "scatter_update/scatter_elements_update_kernel_ref.h" #include "openvino/core/dimension.hpp" #include @@ -53,9 +55,54 @@ kernel_selector::concat_axis convert_axis(int64_t axis, size_t rank) { } // namespace +enum 
class kv_stage : uint8_t { scatter_update, concat, beam_table, dq, scale_concat, zp_concat }; + +struct stages_helper { + std::vector stages; + + void verify() const { + if (try_get_index(kv_stage::scatter_update).has_value()) { + const auto is_indirect = try_get_index(kv_stage::beam_table).has_value(); + const auto is_compress = try_get_index(kv_stage::dq).has_value(); + OPENVINO_ASSERT(!is_indirect && !is_compress); + } + } + + void save(BinaryOutputBuffer& ob) const { + ob << stages.size(); + for (auto& stage : stages) { + ob << static_cast(stage); + } + } + + void load(BinaryInputBuffer& ib) { + size_t stages_size = 0; + ib >> stages_size; + stages.resize(stages_size); + for (auto& stage : stages) { + uint8_t stage_ = 0; + ib >> stage_; + stage = static_cast(stage_); + } + } + + std::optional try_get_index(kv_stage stage) const noexcept { + if (const auto it = std::find(stages.begin(), stages.end(), stage); it != stages.end()) { + return static_cast(std::distance(stages.begin(), it)); + } + return {}; + } + + size_t get_index(kv_stage stage) const noexcept { + const auto idx = try_get_index(stage); + OPENVINO_ASSERT(idx.has_value(), "expect stage ", static_cast(stage), " exist"); + return *idx; + } + +}; + struct kv_cache_impl : multi_stage_primitive { using parent = multi_stage_primitive; - using parent::parent; using kernel_selector_t = kernel_selector::concatenation_kernel_selector; using kernel_params_t = kernel_selector::concatenation_params; @@ -65,43 +112,71 @@ struct kv_cache_impl : multi_stage_primitive { using dq_kernel_selector_t = kernel_selector::dynamic_quantize_kernel_selector; using dq_kernel_params_t = kernel_selector::dynamic_quantize_params; + using scatter_kernel_selector_t = kernel_selector::scatter_elements_update_kernel_selector; + using scatter_kernel_params_t = kernel_selector::scatter_elements_update_params; + DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::kv_cache_impl) + kv_cache_impl() {} + + kv_cache_impl(const kv_cache_impl& other) 
+ : parent(other) + , stages(other.stages) { + } + + kv_cache_impl(const std::vector& kd, const std::vector& stages_) + : parent(kd) + , stages{stages_} { + OPENVINO_ASSERT(_kernels_data.size() == stages.stages.size()); + stages.verify(); + } + std::unique_ptr clone() const override { return make_deep_copy(*this); } - const size_t concat_stage = 0; - const size_t beam_table_stage = 1; - const size_t dq_stage = 2; - const size_t scale_concat_stage = 3; - const size_t zp_concat_stage = 4; + stages_helper stages; cldnn::memory::ptr beam_table_prev = nullptr; cldnn::memory::ptr beam_table_new = nullptr; + void save(BinaryOutputBuffer& ob) const override { + parent::save(ob); + stages.save(ob); + } + void load(BinaryInputBuffer& ib) override { parent::load(ib); + stages.load(ib); + OPENVINO_ASSERT(_kernels_data.size() == stages.stages.size()); + stages.verify(); if (is_dynamic()) { - auto& kernel_selector = kernel_selector_t::Instance(); - auto kernel_impl = kernel_selector.GetImplementation(_kernels_data[concat_stage].kernelName); - kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[concat_stage]); - if (_kernels_data.size() >= 2) { + if (const auto scatter_update_stage = stages.try_get_index(kv_stage::scatter_update)) { + auto& scatter_kernel_selector = scatter_kernel_selector_t::Instance(); + auto scatter_kernel_impl = scatter_kernel_selector.GetImplementation(_kernels_data[*scatter_update_stage].kernelName); + scatter_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[*scatter_update_stage]); + } + const auto concat_stage = stages.get_index(kv_stage::concat); + auto& concat_kernel_selector = kernel_selector_t::Instance(); + auto concat_kernel_impl = concat_kernel_selector.GetImplementation(_kernels_data[concat_stage].kernelName); + concat_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[concat_stage]); + + if (const auto beam_table_stage = stages.try_get_index(kv_stage::beam_table)) { auto& bt_kernel_selector = bt_kernel_selector_t::Instance(); - auto 
bt_kernel_impl = bt_kernel_selector.GetImplementation(_kernels_data[beam_table_stage].kernelName); - bt_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[beam_table_stage]); + auto bt_kernel_impl = bt_kernel_selector.GetImplementation(_kernels_data[*beam_table_stage].kernelName); + bt_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[*beam_table_stage]); } - if (_kernels_data.size() >= 3) { + if (const auto dq_stage = stages.try_get_index(kv_stage::dq)) { auto& dq_kernel_selector = dq_kernel_selector_t::Instance(); - auto dq_kernel_impl = dq_kernel_selector.GetImplementation(_kernels_data[dq_stage].kernelName); - dq_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[dq_stage]); + auto dq_kernel_impl = dq_kernel_selector.GetImplementation(_kernels_data[*dq_stage].kernelName); + dq_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[*dq_stage]); } - if (_kernels_data.size() >= 4) { + if (const auto scale_concat_stage = stages.try_get_index(kv_stage::scale_concat)) { auto& scale_zp_concat_kernel_selector = kernel_selector_t::Instance(); - auto scale_zp_concat_kernel_impl = scale_zp_concat_kernel_selector.GetImplementation(_kernels_data[scale_concat_stage].kernelName); - scale_zp_concat_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[scale_concat_stage]); + auto scale_zp_concat_kernel_impl = scale_zp_concat_kernel_selector.GetImplementation(_kernels_data[*scale_concat_stage].kernelName); + scale_zp_concat_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[*scale_concat_stage]); } } } @@ -112,24 +187,35 @@ struct kv_cache_impl : multi_stage_primitive { // output buffers order: [current, (beam_table), (current_scale), (current_zp)] kernel_arguments_data args; args.shape_info = instance.shape_info_memory_ptr(); - if (stage == concat_stage) { + switch (stages.stages[stage]) { + case kv_stage::scatter_update: + // ScatterElementsUpdate: data=past_kv[0], indices=dst_idx[3+indirect_offset], updates=updated_data[4+indirect_offset] + args.inputs = { 
instance.input_memory_ptr(0), instance.input_memory_ptr(3), instance.input_memory_ptr(4) }; + args.outputs = { instance.input_memory_ptr(0) }; + break; + case kv_stage::concat: args.inputs = { instance.input_memory_ptr(0), instance.input_memory_ptr(1) }; args.outputs = { instance.output_memory_ptr(0) }; - } else if (stage == beam_table_stage) { + break; + case kv_stage::beam_table: args.inputs = { beam_table_prev, instance.input_memory_ptr(2) }; args.outputs = { beam_table_new }; - } else if (stage == dq_stage) { + break; + case kv_stage::dq: args.inputs = { instance.input_memory_ptr(1) }; args.outputs = { instance.output_memory_ptr(0) }; for (size_t i = 2; i < instance.outputs_memory_count(); i++) { args.outputs.push_back(instance.output_memory_ptr(i)); } - } else if (stage == scale_concat_stage) { + break; + case kv_stage::scale_concat: args.inputs = { instance.input_memory_ptr(3) }; args.outputs = { instance.output_memory_ptr(2) }; - } else if (stage == zp_concat_stage) { + break; + case kv_stage::zp_concat: args.inputs = { instance.input_memory_ptr(4) }; args.outputs = { instance.output_memory_ptr(3) }; + break; } return args; } @@ -146,8 +232,10 @@ struct kv_cache_impl : multi_stage_primitive { kernel_offset += _kernels_data[s].kernels.size(); } for (size_t kd_idx = 0; kd_idx < _kernels_data[stage].kernels.size(); ++kd_idx) { - if (_kernels_data[stage].kernels[kd_idx].skip_execution) + if (_kernels_data[stage].kernels[kd_idx].skip_execution) { + GPU_DEBUG_TRACE_DETAIL << "Skip stage " << stage << " kernel " << kd_idx << std::endl; continue; + } size_t idx_final = kernel_offset + kd_idx; // If any user of the prim's users is CPU implementation or network's output, set prim as a output event (event won't be nullptr) @@ -186,14 +274,20 @@ struct kv_cache_impl : multi_stage_primitive { const auto& desc = instance.get_typed_desc(); auto& variable = instance.get_network().get_variable(desc->variable_info.variable_id); std::vector res_events; + const auto& impl_param 
= *instance.get_impl_params(); + if (const auto scatter_update_stage = stages.try_get_index(kv_stage::scatter_update)) { + execute_stage(events, instance, res_events, *scatter_update_stage); + } + const auto concat_stage = stages.get_index(kv_stage::concat); execute_stage(events, instance, res_events, concat_stage); - const auto& impl_param = *instance.get_impl_params(); const auto& kv_in_shape = impl_param.input_layouts[0].get_partial_shape(); const auto& kv_out_shape = impl_param.output_layouts[0].get_partial_shape(); if (desc->indirect && ((kv_out_shape[desc->gather_axis].get_length() > 1) || (kv_in_shape[desc->concat_axis].get_length() == 0))) { + const auto beam_table_stage = stages.get_index(kv_stage::beam_table); + const auto bt_alloc_type = engine.get_preferred_memory_allocation_type(false); auto beam_table_state = dynamic_cast(variable).get_beam_table_state(); @@ -226,14 +320,17 @@ struct kv_cache_impl : multi_stage_primitive { } if (desc->compressed) { + const auto scale_concat_stage = stages.get_index(kv_stage::scale_concat); // Copy scales to the new buffer if needed - execute_stage(events, instance, res_events, scale_concat_stage, scale_concat_stage); + execute_stage(events, instance, res_events, scale_concat_stage); if (desc->get_compression_zp_inputs_num() > 0) { + const auto zp_concat_stage = stages.get_index(kv_stage::zp_concat); // Copy zero points to the new buffer if needed - execute_stage(events, instance, res_events, zp_concat_stage, zp_concat_stage); + execute_stage(events, instance, res_events, zp_concat_stage); } + const auto dq_stage = stages.get_index(kv_stage::dq); // Perform dynamic quantization of new token data and append result to the KV-cache auto dq_params = get_dq_update_kernel_params(impl_param, impl_param.is_dynamic()); (_kernels_data[dq_stage].update_dispatch_data_func)(dq_params, _kernels_data[dq_stage]); @@ -296,6 +393,37 @@ struct kv_cache_impl : multi_stage_primitive { return layout{beam_table_shape, 
impl_param.output_layouts[1].data_type, format::get_default_format(beam_table_shape.size())}; } + static scatter_kernel_params_t get_scatter_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) { + const auto primitive = impl_param.typed_desc(); + auto params = get_default_params(impl_param, is_shape_agnostic); + + auto inputs_count = 3; + + params.inputs.resize(inputs_count); + params.inputs[0] = convert_data_tensor(impl_param.input_layouts[0], tensor()); + params.inputs[1] = convert_data_tensor(impl_param.input_layouts[3], tensor()); + params.inputs[2] = convert_data_tensor(impl_param.input_layouts[4], tensor()); + params.outputs[0] = convert_data_tensor(impl_param.output_layouts[0], tensor()); + + params.axis = kernel_selector::scatter_update_axis::Y; // always update axis 2 which is KV seq_len dimension + params.mode = kernel_selector::ScatterUpdateReduction::NONE; + params.use_init_val = true; + + const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; // [kv_past, kv_new_token, [beam_idx], past_seq_len, dst_idx, update_data] + std::map in_tensor_to_offset_map = { + {0, in_offsets_map.at(0)}, + {1, in_offsets_map.at(3)}, + {2, in_offsets_map.at(4)}, + }; + std::map out_tensor_to_offset_map = { + {0, in_offsets_map.at(0)}, + }; + + params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map); + + return params; + } + static kernel_params_t get_concat_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) { const auto& primitive = impl_param.typed_desc(); auto params = get_default_params(impl_param, is_shape_agnostic); @@ -351,7 +479,7 @@ struct kv_cache_impl : multi_stage_primitive { const auto compression_inputs = desc->get_compression_scales_inputs_num() + desc->get_compression_zp_inputs_num(); const auto beam_table_past_idx = 3 + compression_inputs; - const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; // [kv_past, kv_new_token, [beam_idx, 
[scale_past], [zp_past], beam_table_past]] + const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; // [kv_past, kv_new_token, [beam_idx, [past_seq_len, dst_idx, update_data], [scale_past], [zp_past], beam_table_past]] const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset; // [kv_present, beam_table_present, compression_scale_present] std::map in_tensor_to_offset_map = { {0, in_offsets_map.at(beam_table_past_idx)}, // beam_table_past @@ -437,7 +565,7 @@ struct kv_cache_impl : multi_stage_primitive { params.inputs[0] = convert_data_tensor(comp_scale_past_layout); params.outputs[0] = convert_data_tensor(comp_scale_present_layout); - const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; + const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; // [kv_past, kv_new_token, [beam_idx, [past_seq_len, dst_idx, update_data], [scale_past], [zp_past], beam_table_past]] const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset; std::map in_tensor_to_offset_map = { @@ -454,11 +582,22 @@ struct kv_cache_impl : multi_stage_primitive { static std::unique_ptr create(const typed_program_node& arg, const kernel_impl_params& impl_param) { std::vector kernels_data; + std::vector stages; + const auto desc = impl_param.typed_desc(); + if (desc->update_kv) { + auto scatter_kernel_params = get_scatter_kernel_params(impl_param, impl_param.is_dynamic()); + auto& scatter_kernel_selector = scatter_kernel_selector_t::Instance(); + auto scatter_kernels = scatter_kernel_selector.GetBestKernels(scatter_kernel_params); + if (!scatter_kernels.empty()) { + kernels_data.push_back(scatter_kernels[0]); + stages.push_back(kv_stage::scatter_update); + } + } auto concat_kernel_params = get_concat_kernel_params(impl_param, impl_param.is_dynamic()); auto& concat_kernel_selector = kernel_selector_t::Instance(); kernels_data.push_back(concat_kernel_selector.get_best_kernel(concat_kernel_params)); + stages.push_back(kv_stage::concat); - 
const auto desc = impl_param.typed_desc(); const bool indirect = desc->indirect; const bool compressed = desc->compressed; const bool has_zp_input = desc->get_compression_zp_inputs_num() > 0; @@ -466,27 +605,39 @@ struct kv_cache_impl : multi_stage_primitive { auto bt_update_kernel_params = get_bt_update_kernel_params(impl_param, false); auto& bt_update_kernel_selector = bt_kernel_selector_t::Instance(); kernels_data.push_back(bt_update_kernel_selector.get_best_kernel(bt_update_kernel_params)); + stages.push_back(kv_stage::beam_table); } if (compressed) { auto dq_kernel_params = get_dq_update_kernel_params(impl_param, impl_param.is_dynamic()); auto& dq_kernel_selector = dq_kernel_selector_t::Instance(); kernels_data.push_back(dq_kernel_selector.get_best_kernel(dq_kernel_params)); + stages.push_back(kv_stage::dq); auto& concat_scale_zp_kernel_selector = kernel_selector_t::Instance(); auto concat_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, true, impl_param.is_dynamic()); kernels_data.push_back(concat_scale_zp_kernel_selector.get_best_kernel(concat_scale_kernel_params)); + stages.push_back(kv_stage::scale_concat); if (has_zp_input) { auto concat_zp_kernel_params = get_compression_scale_update_kernel_params(impl_param, false, impl_param.is_dynamic()); kernels_data.push_back(concat_scale_zp_kernel_selector.get_best_kernel(concat_zp_kernel_params)); + stages.push_back(kv_stage::zp_concat); } } - return std::make_unique(kernels_data); + return std::make_unique(kernels_data, stages); } void update_dispatch_data(const kernel_impl_params& impl_param) override { + if (const auto scatter_update_stage = stages.try_get_index(kv_stage::scatter_update)) { + auto scatter_kernel_params = get_scatter_kernel_params(impl_param, true); + (_kernels_data[*scatter_update_stage].update_dispatch_data_func)(scatter_kernel_params, _kernels_data[*scatter_update_stage]); + // Skip execution if indices tensor is empty + 
_kernels_data[*scatter_update_stage].kernels[0].skip_execution = impl_param.get_input_layout(3).count() == 0; + } + // If model loaded from cache, params are not initialized, so we create a new object and reuse it in the future + const auto concat_stage = stages.get_index(kv_stage::concat); if (_kernels_data[concat_stage].params == nullptr) { _kernels_data[concat_stage].params = std::make_shared(get_concat_kernel_params(impl_param, true)); } @@ -507,12 +658,14 @@ struct kv_cache_impl : multi_stage_primitive { // variables memory was reallocated and we have to copy past KV-cache to new memory) _kernels_data[concat_stage].kernels[1].skip_execution = true; + const auto scale_concat_stage = stages.get_index(kv_stage::scale_concat); // Update dynamic quantization parameters auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, true, impl_param.is_dynamic()); (_kernels_data[scale_concat_stage].update_dispatch_data_func)(comp_scale_kernel_params, _kernels_data[scale_concat_stage]); _kernels_data[scale_concat_stage].kernels[0].skip_execution = impl_param._can_be_optimized || impl_param.get_input_layout(3).count() == 0; if (impl_param.typed_desc()->get_compression_zp_inputs_num() > 0) { + const auto zp_concat_stage = stages.get_index(kv_stage::zp_concat); auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, false, impl_param.is_dynamic()); (_kernels_data[zp_concat_stage].update_dispatch_data_func)(comp_scale_kernel_params, _kernels_data[zp_concat_stage]); _kernels_data[zp_concat_stage].kernels[0].skip_execution = impl_param._can_be_optimized || impl_param.get_input_layout(4).count() == 0; diff --git a/src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h b/src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h index 0467c09581f7ba..2f04f20ab55d93 100644 --- a/src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/kv_cache_inst.h @@ -10,6 +10,8 @@ #include 
"primitive_inst.h" #include "variable.hpp" +#include + namespace cldnn { template <> @@ -22,7 +24,14 @@ struct typed_program_node : public typed_program_node_base { program_node& input() const { return get_dependency(0); } - std::vector get_shape_infer_dependencies() const override { return {}; } + std::vector get_shape_infer_dependencies() const override { + std::vector vec; + const auto desc = get_primitive(); + if (desc->trim) { + vec.push_back(desc->indirect ? 3 : 2); // past_seq_len + } + return vec; + } std::vector get_shape_info_input_layouts() const override { std::vector res; @@ -88,11 +97,18 @@ class typed_primitive_inst : public typed_primitive_inst_base; diff --git a/src/plugins/intel_gpu/src/graph/kv_cache.cpp b/src/plugins/intel_gpu/src/graph/kv_cache.cpp index d7a6fe048ca5c1..a4c2ed3fbd5e34 100644 --- a/src/plugins/intel_gpu/src/graph/kv_cache.cpp +++ b/src/plugins/intel_gpu/src/graph/kv_cache.cpp @@ -10,6 +10,7 @@ #include "primitive_type_base.h" #include #include +#include "utils.hpp" namespace cldnn { GPU_DEFINE_PRIMITIVE_TYPE_ID(kv_cache) @@ -21,6 +22,38 @@ kv_cache_inst::typed_primitive_inst(network& network, const kv_cache_node& node) kv_cache_id = kv_cache_counter++; } +int64_t kv_cache_inst::compute_trim_length(const kernel_impl_params& impl_param, const kv_cache& desc) { + if (!desc.trim) + return 0; + + const size_t past_seq_len_idx = desc.indirect ? 
3 : 2; + const auto mem_dep_it = impl_param.memory_deps.find(past_seq_len_idx); + if (mem_dep_it == impl_param.memory_deps.end()) + return 0; + + const auto& past_seq_len_mem = mem_dep_it->second; + const auto past_seq_len_layout = past_seq_len_mem->get_layout(); + if (past_seq_len_layout.count() == 0) + return 0; + + OPENVINO_ASSERT(past_seq_len_layout.count() == 1); + cldnn::mem_lock past_seq_len_mem_lock(past_seq_len_mem, impl_param.get_stream()); + auto past_seq_len_tensor = make_tensor(past_seq_len_layout, past_seq_len_mem_lock.data()); + const auto past_dim_updated = ov::get_tensor_data_as(past_seq_len_tensor); + + const auto& past_layout = impl_param.get_input_layout(0); + const auto past_shape = past_layout.get_partial_shape(); + const auto sequence_axis = kv_cache_inst::get_sequence_axis(desc.concat_axis, past_shape.size()); + OPENVINO_ASSERT(sequence_axis >= 0); + const auto sequence_axis_idx = static_cast(sequence_axis); + OPENVINO_ASSERT(past_shape[sequence_axis_idx].is_static()); + + const auto trim_length = past_shape[sequence_axis_idx].get_length() - past_dim_updated[0]; + OPENVINO_ASSERT(trim_length >= 0, "[GPU] past_seq_len shouldn't exceed stored sequence length"); + + return trim_length; +} + layout kv_cache_inst::calc_output_layout(const kv_cache_node& node, kernel_impl_params const& impl_param) { return impl_param.input_layouts[0]; } @@ -31,8 +64,9 @@ std::vector kv_cache_inst::calc_output_layouts(kv_cache_node const& /*no std::vector input_shapes = {impl_param.get_input_layout(0).get(), impl_param.get_input_layout(1).get()}; + size_t input_idx = 2; if (desc->indirect) { - input_shapes.push_back(impl_param.get_input_layout(2).get()); + input_shapes.push_back(impl_param.get_input_layout(input_idx++).get()); } if (desc->compressed) { @@ -43,13 +77,17 @@ std::vector kv_cache_inst::calc_output_layouts(kv_cache_node const& /*no } } + const auto kv_cache_trim_length = kv_cache_inst::compute_trim_length(impl_param, *desc); + std::vector output_shapes; 
if (desc->compressed) { + OPENVINO_ASSERT(kv_cache_trim_length == 0); // compressed kv should not do any trim ov::intel_gpu::op::KVCacheCompressed op; op.set_output_size(desc->num_outputs); op.set_concat_axis(desc->concat_axis); op.set_gather_axis(desc->gather_axis); op.set_quantization_attrs(desc->quantization_attributes); + op.set_trim(desc->trim); output_shapes = shape_infer(&op, input_shapes); } else { @@ -57,6 +95,8 @@ std::vector kv_cache_inst::calc_output_layouts(kv_cache_node const& /*no op.set_output_size(desc->num_outputs); op.set_concat_axis(desc->concat_axis); op.set_gather_axis(desc->gather_axis); + op.set_trim(desc->trim); + op.set_trim_length(kv_cache_trim_length); output_shapes = shape_infer(&op, input_shapes); } diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp index 637aa2744ed155..fc3129ba559739 100644 --- a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp +++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp @@ -54,12 +54,14 @@ #include "intel_gpu/runtime/memory.hpp" #include "intel_gpu/runtime/debug_configuration.hpp" #include "intel_gpu/runtime/compilation_context.hpp" +#include "intel_gpu/runtime/tensor_accessor.hpp" #include "json_object.h" #include #include #include #include +#include "utils.hpp" #ifdef ENABLE_ONEDNN_FOR_GPU #include @@ -514,7 +516,20 @@ void primitive_inst::update_shape() { } if (get_node().is_type()) { + auto& kv_inst = downcast(*this); auto desc = get_node().as().get_primitive(); + const auto trim_length = kv_cache_inst::compute_trim_length(*_impl_params, *desc); + if (trim_length > 0) { + OPENVINO_ASSERT(!(desc->indirect || desc->compressed), + "[GPU] Unsupported trim for indirect or compressed kvcache: indirect:", + desc->indirect, + " compressed:", + desc->compressed, + " trim:", + trim_length); + } + kv_inst.set_trim_length(trim_length); + auto var_mem_size = get_network().get_variable(desc->variable_info.variable_id).get_actual_mem_size(); // 
Need to trigger realloc_if_needed if (var_mem_size < _impl_params->get_output_layout(0).get_linear_size()) @@ -1450,6 +1465,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() { if (!get_node().is_type()) return; + auto& kv_inst = downcast(*this); _impl_params->_can_be_optimized = false; if (_impl_params->get_input_layout(0).count() == 0) { @@ -1474,6 +1490,14 @@ void primitive_inst::do_runtime_in_place_kv_cache() { GPU_DEBUG_TRACE_DETAIL << "[do runtime kv_cache opt] " << id() << " initial present_layout : " << present_layout.to_string() << std::endl; GPU_DEBUG_TRACE_DETAIL << "[do runtime kv_cache opt] " << id() << " initial past_layout : " << past_layout.to_string() << std::endl; + if (desc->trim && kv_inst.get_trim_length() > 0) { + GPU_DEBUG_TRACE_DETAIL << "[do runtime kv_cache opt] " << id() << " kv cache trim_length : " << kv_inst.get_trim_length() << std::endl; + auto trimmed_past_shape = past_layout.get_shape(); + trimmed_past_shape[sequence_axis] -= kv_inst.get_trim_length(); + past_layout.set_partial_shape(trimmed_past_shape); + auto past_layout_pad = past_layout.data_padding._upper_size[sequence_axis] + kv_inst.get_trim_length(); + kv_cache_inst::update_pad(past_layout, past_layout_pad, sequence_axis); + } auto max_pad = kv_cache_inst::get_max_pad(past_layout, _deps[0].first->_max_output_layout_count[0], sequence_axis, "past_layout"); const auto new_seq_len = static_cast(new_layout.get_shape()[sequence_axis]); // In chatbot scenario, when chat history must be stored in kvcache, new_seq_len may not be 1 even if max_pad is greater than 0 diff --git a/src/plugins/intel_gpu/src/plugin/ops/kv_cache.cpp b/src/plugins/intel_gpu/src/plugin/ops/kv_cache.cpp index 0d4da5abce1516..20ba7802654ec6 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/kv_cache.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/kv_cache.cpp @@ -23,7 +23,7 @@ namespace ov::intel_gpu { namespace { void CreateKVCacheOp(ProgramBuilder& p, const std::shared_ptr& op) { - 
validate_inputs_count(op, {2, 3}); + validate_inputs_count(op, {2, 3, 4, 5}); auto inputs = p.GetInputInfo(op); int64_t rank = op->get_input_partial_shape(0).size(); auto prim = cldnn::kv_cache(layer_type_name_ID(op), @@ -31,7 +31,9 @@ void CreateKVCacheOp(ProgramBuilder& p, const std::shared_ptrget_variable()->get_info(), ov::util::normalize(op->get_concat_axis(), rank), ov::util::normalize(op->get_gather_axis(), rank), - op->get_indirect()); + op->get_indirect(), + op->get_trim(), + op->get_update_kv()); prim.num_outputs = op->get_output_size(); prim.output_data_types = get_output_data_types(op); @@ -40,7 +42,9 @@ void CreateKVCacheOp(ProgramBuilder& p, const std::shared_ptr& op) { - validate_inputs_count(op, {4, 5}); + OPENVINO_ASSERT(!op->get_update_kv()); // compressed kv does not support update_kv + OPENVINO_ASSERT(op->get_indirect()); // compressed kv must be indirect + validate_inputs_count(op, {4, 5, 6}); auto inputs = p.GetInputInfo(op); int64_t rank = op->get_input_partial_shape(0).size(); auto prim = cldnn::kv_cache(layer_type_name_ID(op), @@ -48,7 +52,9 @@ void CreateKVCacheCompressedOp(ProgramBuilder& p, const std::shared_ptrget_variable()->get_info(), ov::util::normalize(op->get_concat_axis(), rank), ov::util::normalize(op->get_gather_axis(), rank), - op->get_indirect()); + op->get_indirect(), + op->get_trim(), + false); prim.compressed = true; prim.quantization_attributes = op->get_quantization_attrs(); diff --git a/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp b/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp index dc4ce6c0d09b12..148bfd8222bcaf 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp @@ -39,6 +39,31 @@ void replace_node_unsafe(const std::shared_ptr& target, const std::sha target->clear_control_dependents(); } +std::shared_ptr make_kvcache_indirect(const std::shared_ptr& 
kv_cache_node, + const std::shared_ptr& gather_input_node, + const std::shared_ptr& beam_idx_node, + const int64_t gather_axis) { + OPENVINO_ASSERT(!kv_cache_node->get_update_kv()); + if (kv_cache_node->get_trim()) { + return std::make_shared(gather_input_node, + kv_cache_node->input(1).get_source_output(), + beam_idx_node, + kv_cache_node->input(2).get_source_output(), + kv_cache_node->get_variable(), + kv_cache_node->get_concat_axis(), + gather_axis, + kv_cache_node->get_output_element_type(0)); + } else { + return std::make_shared(gather_input_node, + kv_cache_node->input(1).get_source_output(), + beam_idx_node, + kv_cache_node->get_variable(), + kv_cache_node->get_concat_axis(), + gather_axis, + kv_cache_node->get_output_element_type(0)); + } +} + } // namespace namespace ov::intel_gpu { @@ -53,7 +78,9 @@ IndirectGemmOpt::IndirectGemmOpt() { return value.size() == 1 && (value[0] == 0 || value[0] == 1); })); auto gather_past = wrap_type({gather_input, beam_idx, axis_const}); - auto kv_cache = wrap_type({gather_past, any_input()}); + auto kv_cache_notrim = wrap_type({gather_past, any_input()}); + auto kv_cache_trim = wrap_type({gather_past, any_input(), any_input()}); + auto kv_cache = std::make_shared(OutputVector{kv_cache_notrim, kv_cache_trim}); auto matmul_0 = wrap_type({kv_cache, any_input()}); auto matmul_1 = wrap_type({any_input(), kv_cache}); auto matmul = std::make_shared(OutputVector{matmul_0, matmul_1}); @@ -72,13 +99,7 @@ IndirectGemmOpt::IndirectGemmOpt() { auto gather_axis = gather_node->get_axis(); ov::replace_node(gather_node, gather_input_node); - auto indirect_kv_cache = std::make_shared(gather_input_node, - kv_cache_node->input(1).get_source_output(), - beam_idx_node, - kv_cache_node->get_variable(), - kv_cache_node->get_concat_axis(), - gather_axis, - kv_cache_node->get_output_element_type(0)); + auto indirect_kv_cache = make_kvcache_indirect(kv_cache_node, gather_input_node, beam_idx_node, gather_axis); 
indirect_kv_cache->set_friendly_name(kv_cache_node->get_friendly_name()); ov::copy_runtime_info(kv_cache_node, indirect_kv_cache); @@ -126,8 +147,12 @@ IndirectSDPAOpt::IndirectSDPAOpt() { })); auto gather_past_0 = wrap_type({gather_input_0, beam_idx, axis_const}); auto gather_past_1 = wrap_type({gather_input_1, beam_idx, axis_const}); - auto kv_cache_0 = wrap_type({gather_past_0, any_input()}); - auto kv_cache_1 = wrap_type({gather_past_1, any_input()}); + auto kv_cache_notrim_0 = wrap_type({gather_past_0, any_input()}); + auto kv_cache_notrim_1 = wrap_type({gather_past_1, any_input()}); + auto kv_cache_trim_0 = wrap_type({gather_past_0, any_input(), any_input()}); + auto kv_cache_trim_1 = wrap_type({gather_past_1, any_input(), any_input()}); + auto kv_cache_0 = std::make_shared(OutputVector{kv_cache_notrim_0, kv_cache_trim_0}); + auto kv_cache_1 = std::make_shared(OutputVector{kv_cache_notrim_1, kv_cache_trim_1}); auto input_attn_mask = any_input(); auto input_scale = any_input(); @@ -159,21 +184,8 @@ IndirectSDPAOpt::IndirectSDPAOpt() { ov::replace_node(gather_node_0, gather_input_node_0); ov::replace_node(gather_node_1, gather_input_node_1); - auto indirect_kv_cache_0 = std::make_shared(gather_input_node_0, - kv_cache_node_0->input_value(1), - beam_idx_node, - kv_cache_node_0->get_variable(), - kv_cache_node_0->get_concat_axis(), - gather_axis_0, - kv_cache_node_0->get_output_element_type(0)); - - auto indirect_kv_cache_1 = std::make_shared(gather_input_node_1, - kv_cache_node_1->input_value(1), - beam_idx_node, - kv_cache_node_1->get_variable(), - kv_cache_node_1->get_concat_axis(), - gather_axis_1, - kv_cache_node_1->get_output_element_type(0)); + auto indirect_kv_cache_0 = make_kvcache_indirect(kv_cache_node_0, gather_input_node_0, beam_idx_node, gather_axis_0); + auto indirect_kv_cache_1 = make_kvcache_indirect(kv_cache_node_1, gather_input_node_1, beam_idx_node, gather_axis_1); indirect_kv_cache_0->set_friendly_name(kv_cache_node_0->get_friendly_name()); 
indirect_kv_cache_1->set_friendly_name(kv_cache_node_1->get_friendly_name()); diff --git a/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.cpp b/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.cpp index 7430dc7ca9cfa8..b6a29ba35af820 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.cpp @@ -104,8 +104,11 @@ std::shared_ptr const ov::op::internal::DynamicQuantize::Attributes& quantization_attrs) { OutputVector kv_cache_inputs = { past_rv_node->output(0), kv_cache_node->input_value(1), - kv_cache_node->input_value(2), - past_rv_node->output(1) }; + kv_cache_node->input_value(2) }; + if (kv_cache_node->get_trim()) { + kv_cache_inputs.push_back(kv_cache_node->input_value(3)); + } + kv_cache_inputs.push_back(past_rv_node->output(1)); if (quantization_attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && quantization_attrs.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) @@ -113,6 +116,7 @@ std::shared_ptr auto new_kv_cache = std::make_shared(kv_cache_inputs, kv_cache_node->get_variable(), + kv_cache_node->get_trim(), kv_cache_node->get_concat_axis(), kv_cache_node->get_gather_axis(), quantization_attrs); @@ -147,16 +151,21 @@ KVCacheCompressionMatcher::KVCacheCompressionMatcher(ov::element::Type compressi << "single_buffer_for_scales_and_zp=" << combine_scales_and_zp << "\n"; auto query = any_input(); + auto past_seq_len = any_input(); auto key_past = wrap_type(); auto key_new_token = any_input(); auto key_beam_idx = any_input(); - auto key_cache = wrap_type({key_past, key_new_token, key_beam_idx}); + auto key_cache_notrim = wrap_type({key_past, key_new_token, key_beam_idx}); + auto key_cache_trim = wrap_type({key_past, key_new_token, key_beam_idx, past_seq_len}); + auto key_cache = std::make_shared(OutputVector{key_cache_notrim, 
key_cache_trim}); auto value_past = wrap_type(); auto value_new_token = any_input(); auto value_beam_idx = any_input(); - auto value_cache = wrap_type({value_past, value_new_token, value_beam_idx}); + auto value_cache_notrim = wrap_type({value_past, value_new_token, value_beam_idx}); + auto value_cache_trim = wrap_type({value_past, value_new_token, value_beam_idx, past_seq_len}); + auto value_cache = std::make_shared(OutputVector{value_cache_notrim, value_cache_trim}); auto input_attn_mask = any_input(); auto input_scale = any_input(); @@ -181,6 +190,9 @@ KVCacheCompressionMatcher::KVCacheCompressionMatcher(ov::element::Type compressi auto key_new_token_node = pattern_map.at(key_new_token).get_node_shared_ptr(); auto key_cache_node = ov::as_type_ptr(pattern_map.at(key_cache).get_node_shared_ptr()); auto value_cache_node = ov::as_type_ptr(pattern_map.at(value_cache).get_node_shared_ptr()); + if (!key_cache_node->get_indirect() || !value_cache_node->get_indirect()) { + return false; + } auto sdpa_node = ov::as_type_ptr(m.get_match_root()); auto key_past_rv_node = ov::as_type_ptr(pattern_map.at(key_past).get_node_shared_ptr()); diff --git a/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_fusion.cpp b/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_fusion.cpp index fbdb6c3b82c0df..c5961c090334f4 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_fusion.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_fusion.cpp @@ -17,7 +17,10 @@ #include "openvino/op/gather.hpp" #include "openvino/op/parameter.hpp" #include "openvino/op/read_value.hpp" +#include "openvino/op/scatter_elements_update.hpp" #include "openvino/op/sink.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/strided_slice.hpp" #include "openvino/pass/graph_rewrite.hpp" #include "openvino/pass/pattern/op/label.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" @@ -38,7 +41,18 @@ KVCacheFusionMatcher::KVCacheFusionMatcher() { auto beam_idx = 
wrap_type(); auto gather_past = wrap_type({gather_input, beam_idx, wrap_type()}); auto gather_convert = wrap_type({gather_past}); - auto concat_past_input = std::make_shared(OutputVector{past, convert_past, gather_past, gather_convert}); + auto dst_idx = wrap_type(); + auto gather_update = wrap_type(); + auto update_kv = wrap_type({gather_input, dst_idx, gather_update, wrap_type()}); + auto start = wrap_type(); + auto past_seq_len = any_input(); + auto stride = wrap_type(); + auto step = wrap_type(); + auto slice_axes = wrap_type(); + auto trim_input = std::make_shared(OutputVector{gather_input, gather_past, gather_convert, update_kv}); + auto trim_past = wrap_type({trim_input, start, past_seq_len, step, slice_axes}); + auto trim_past2 = wrap_type({trim_input, start, past_seq_len, stride}); + auto concat_past_input = std::make_shared(OutputVector{trim_input, trim_past, trim_past2}); auto concat = wrap_type({concat_past_input, any_input()}); auto convert_present = wrap_type({concat}); auto present_input = std::make_shared(OutputVector{concat, convert_present}); @@ -80,14 +94,103 @@ KVCacheFusionMatcher::KVCacheFusionMatcher() { ov::copy_runtime_info(past_node, new_read_value_node); ov::replace_node(past_node, new_read_value_node); - if (pattern_map.count(gather_past) > 0) { - kv_cache_node = std::make_shared(pattern_map.at(gather_past).get_node_shared_ptr(), + const bool has_beam_idx = pattern_map.count(gather_past) > 0; + const bool has_update_kv = pattern_map.count(update_kv) > 0; + const bool has_slice = pattern_map.count(trim_past) > 0; + const bool has_strided_slice = pattern_map.count(trim_past2) > 0; + const bool has_trim = has_slice || has_strided_slice; + + const auto adjust_axis_to_positive = [&new_read_value_node](auto axis) ->std::optional { + if (axis >= 0) { + return static_cast(axis); + } else { + const auto input_rank = new_read_value_node->get_output_partial_shape(0).rank(); + if (input_rank.is_static()) { + const auto adjusted_axis = 
input_rank.get_interval().get_min_val() + axis; + if (adjusted_axis >= 0) { + return static_cast(adjusted_axis); + } + } + } + return std::nullopt; + }; + std::optional target_concat_axis = adjust_axis_to_positive(concat_axis); + OPENVINO_ASSERT(target_concat_axis.has_value(), "concat_axis should be valid, get: ", concat_axis); + + std::shared_ptr past_seq_len_node; + if (has_trim) { + past_seq_len_node = pattern_map.at(past_seq_len).get_node_shared_ptr(); + // StridedSlice uses multi-dim for end tensor, extract only the slice dim + if (has_strided_slice) { + const auto strided_slice = ov::as_type_ptr(concat_node->input_value(0).get_node_shared_ptr()); + const auto begin_mask = strided_slice->get_begin_mask(); + const auto end_mask = strided_slice->get_end_mask(); + // begin/end mask should be the same and only last element is 0 (being sliced) + if (begin_mask != end_mask || begin_mask.empty()) { + return false; + } + if (static_cast(std::count(begin_mask.begin(), begin_mask.end(), 1)) != (begin_mask.size() - 1) || begin_mask.back() != 0) { + return false; + } + // slice start and stride should be all 1 + const auto slice_start = ov::as_type_ptr(pattern_map.at(start).get_node_shared_ptr()); + if (const auto start_data = slice_start->cast_vector(); std::any_of(start_data.begin(), start_data.end(), [](const auto val) { + return val != 1; + })) { + return false; + } + const auto slice_stride = ov::as_type_ptr(pattern_map.at(stride).get_node_shared_ptr()); + if (const auto stride_data = slice_stride->cast_vector(); std::any_of(stride_data.begin(), stride_data.end(), [](const auto val) { + return val != 1; + })) { + return false; + } + // sliced axis should be the same with concat_axis + if (begin_mask.size() != *target_concat_axis + 1) { + return false; + } + const auto slice_axis = ov::op::v0::Constant::create(element::i64, Shape{1}, {concat_axis}); + const auto gather_axis = ov::op::v0::Constant::create(element::i64, Shape{1}, {0}); + past_seq_len_node = 
std::make_shared(past_seq_len_node, slice_axis, gather_axis); + } else { + // slice start should be 0 and step should be 1 + const auto slice_start = ov::as_type_ptr(pattern_map.at(start).get_node_shared_ptr()); + if (const auto start_data = slice_start->cast_vector(); start_data.size() != 1 || start_data[0] != 0) { + return false; + } + const auto slice_step = ov::as_type_ptr(pattern_map.at(step).get_node_shared_ptr()); + if (const auto step_data = slice_step->cast_vector(); step_data.size() != 1 || step_data[0] != 1) { + return false; + } + // slice axis should be the same as concat_axis + const auto slice_axis = ov::as_type_ptr(pattern_map.at(slice_axes).get_node_shared_ptr()); + if (const auto axis_data = slice_axis->cast_vector(); + axis_data.size() != 1 || adjust_axis_to_positive(axis_data[0]) != *target_concat_axis) { + return false; + } + } + } + + const auto input0 = has_beam_idx ? pattern_map.at(gather_past).get_node_shared_ptr() : new_read_value_node; + if (has_update_kv) { + OPENVINO_ASSERT(has_trim); + kv_cache_node = std::make_shared(input0, + concat_node->input(1).get_source_output(), + past_seq_len_node, + pattern_map.at(dst_idx).get_node_shared_ptr(), + pattern_map.at(gather_update).get_node_shared_ptr(), + variable, + concat_axis, + new_read_value_node->get_output_element_type(0)); + } else if (has_trim) { + kv_cache_node = std::make_shared(input0, concat_node->input(1).get_source_output(), + past_seq_len_node, variable, concat_axis, new_read_value_node->get_output_element_type(0)); } else { - kv_cache_node = std::make_shared(new_read_value_node, + kv_cache_node = std::make_shared(input0, concat_node->input(1).get_source_output(), variable, concat_axis, diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp index 9aebc29904209a..69daf89cd76175 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp +++ 
b/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp @@ -15,6 +15,7 @@ namespace ov::intel_gpu::op { KVCache::KVCache(const OutputVector& inputs, const std::shared_ptr& past_variable, bool indirect, + bool trim, int64_t concat_axis, int64_t gather_axis, const ov::element::Type output_type) @@ -22,6 +23,7 @@ KVCache::KVCache(const OutputVector& inputs, , m_concat_axis(concat_axis) , m_gather_axis(gather_axis) , m_indirect(indirect) + , m_trim(trim) , m_output_type(output_type) { m_variable = past_variable; } @@ -33,7 +35,7 @@ KVCache::KVCache(const Output& past, int64_t concat_axis, int64_t gather_axis, const ov::element::Type output_type) - : KVCache({past, new_token_data, beam_idx}, past_variable, true, concat_axis, gather_axis, output_type) { + : KVCache({past, new_token_data, beam_idx}, past_variable, true, false, concat_axis, gather_axis, output_type) { if (m_indirect) set_output_size(2); validate_and_infer_types(); @@ -44,8 +46,44 @@ KVCache::KVCache(const Output& past, const std::shared_ptr& past_variable, int64_t concat_axis, const ov::element::Type output_type) - : KVCache({past, new_token_data}, past_variable, false, concat_axis, 0, output_type) { - m_variable = past_variable; + : KVCache({past, new_token_data}, past_variable, false, false, concat_axis, 0, output_type) { + validate_and_infer_types(); +} + +KVCache::KVCache(const Output& past, + const Output& new_token_data, + const Output& past_seq_len, + const std::shared_ptr& past_variable, + int64_t concat_axis, + const ov::element::Type output_type) + : KVCache({past, new_token_data, past_seq_len}, past_variable, false, true, concat_axis, 0, output_type) { + validate_and_infer_types(); +} + +KVCache::KVCache(const Output& past, + const Output& new_token_data, + const Output& beam_idx, + const Output& past_seq_len, + const std::shared_ptr& past_variable, + int64_t concat_axis, + int64_t gather_axis, + const ov::element::Type output_type) + : KVCache({past, new_token_data, beam_idx, 
past_seq_len}, past_variable, true, true, concat_axis, gather_axis, output_type) { + if (m_indirect) + set_output_size(2); + validate_and_infer_types(); +} + +KVCache::KVCache(const Output& past, + const Output& new_token_data, + const Output& past_seq_len, + const Output& dst_idx, + const Output& update_data, + const std::shared_ptr& past_variable, + int64_t concat_axis, + const ov::element::Type output_type) + : KVCache({past, new_token_data, past_seq_len, dst_idx, update_data}, past_variable, false, true, concat_axis, 0, output_type) { + m_update_kv = true; validate_and_infer_types(); } @@ -53,6 +91,8 @@ bool KVCache::visit_attributes(ov::AttributeVisitor& visitor) { visitor.on_attribute("concat_axis", m_concat_axis); visitor.on_attribute("gather_axis", m_gather_axis); visitor.on_attribute("indirect", m_indirect); + visitor.on_attribute("trim", m_trim); + visitor.on_attribute("update_kv", m_update_kv); visitor.on_attribute("output_type", m_output_type); return true; } @@ -87,14 +127,43 @@ std::shared_ptr KVCache::clone_with_new_inputs(const ov::OutputVector& new m_concat_axis, m_output_type); - } else { + } else if (new_args.size() == 3) { + if (m_trim) { + return std::make_shared(new_args.at(0), + new_args.at(1), + new_args.at(2), + m_variable, + m_concat_axis, + m_output_type); + } else { + return std::make_shared(new_args.at(0), + new_args.at(1), + new_args.at(2), + m_variable, + m_concat_axis, + m_gather_axis, + m_output_type); + } + } else if (new_args.size() == 4) { return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), + new_args.at(3), m_variable, m_concat_axis, m_gather_axis, m_output_type); + } else if (new_args.size() == 5) { + return std::make_shared(new_args.at(0), + new_args.at(1), + new_args.at(2), + new_args.at(3), + new_args.at(4), + m_variable, + m_concat_axis, + m_output_type); + } else { + OPENVINO_ASSERT(false); } } @@ -104,13 +173,17 @@ std::vector shape_infer(const KVCache* op, const std::vectorget_gather_axis(); const 
auto& concat_axis = ov::util::normalize(op->get_concat_axis(), input_shapes[0].size()); + const auto trim_length = op->get_trim() ? op->get_trim_length() : 0; + if (trim_length > 0) { + OPENVINO_ASSERT(!op->get_indirect(), "Indirect KVCache should not perform trim"); + } // We update output shape with input1 shape by default, as input1 is always new, and in some situations, input0 shape // has zeros in some dimensions. For example to concat input0 [-1, 0, 0, 0] + input1 [-1, 4, -1, 128] along axis 2, // we could (and should) infer dim value of axis 1 and 3 in this case. if (op->get_output_size() >= 2) { out_shapes[0] = input_shapes[1]; out_shapes[0][gather_axis] = input_shapes[2][0]; - out_shapes[0][concat_axis] += input_shapes[0][concat_axis]; + out_shapes[0][concat_axis] += input_shapes[0][concat_axis] - trim_length; std::vector dims(out_shapes[0].size(), 1); dims[gather_axis] = out_shapes[0][gather_axis]; @@ -118,7 +191,7 @@ std::vector shape_infer(const KVCache* op, const std::vector shape_infer(const KVCache* op, const std::vector& past_variable, + bool trim, int64_t concat_axis, int64_t gather_axis, const QuantizationAttrs& quantization_attrs, const ov::element::Type output_type) - : KVCache(inputs, past_variable, true, concat_axis, gather_axis, output_type) + : KVCache(inputs, past_variable, true, trim, concat_axis, gather_axis, output_type) , m_compressed(true) , m_quantization_attrs(quantization_attrs) { OPENVINO_ASSERT(quantization_attrs.quantization_dt == ov::element::i8, @@ -149,11 +223,12 @@ KVCacheCompressed::KVCacheCompressed(const OutputVector& inputs, void KVCacheCompressed::validate_and_infer_types() { std::vector input_shapes = {m_variable->get_info().data_shape, get_input_partial_shape(1)}; input_shapes.push_back(get_input_partial_shape(2)); - input_shapes.push_back(get_input_partial_shape(3)); + const auto compress_input_offset = m_trim ? 
4 : 3; + input_shapes.push_back(get_input_partial_shape(compress_input_offset + 0)); if (m_quantization_attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && m_quantization_attrs.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) - input_shapes.push_back(get_input_partial_shape(4)); + input_shapes.push_back(get_input_partial_shape(compress_input_offset + 1)); auto shapes = shape_infer(this, input_shapes); @@ -172,6 +247,7 @@ std::shared_ptr KVCacheCompressed::clone_with_new_inputs(const ov::OutputV check_new_args_count(this, new_args); return std::make_shared(new_args, m_variable, + m_trim, m_concat_axis, m_gather_axis, m_quantization_attrs, @@ -182,6 +258,11 @@ std::vector shape_infer(const KVCacheCompressed* op, const std::vector& input_shapes) { std::vector out_shapes = shape_infer(static_cast(op), input_shapes); + const auto trim_length = op->get_trim() ? op->get_trim_length() : 0; + if (trim_length > 0) { + OPENVINO_ASSERT(!op->get_indirect(), "Compressed KVCache should not perform trim"); + } + if (op->get_output_size() >= 3) { ov::op::internal::DynamicQuantize dq_op; dq_op.set_attrs(op->get_quantization_attrs()); diff --git a/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/kv_cache.cpp b/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/kv_cache.cpp index f0120a947395cb..58ec3c5c16a25c 100644 --- a/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/kv_cache.cpp +++ b/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/kv_cache.cpp @@ -6,11 +6,13 @@ #include "common_test_utils/ov_tensor_utils.hpp" #include "common_test_utils/ov_test_utils.hpp" #include "common_test_utils/subgraph_builders/llm_builders.hpp" +#include "openvino/core/type/element_type.hpp" #include "openvino/op/convert.hpp" #include "openvino/op/parameter.hpp" #include "openvino/op/result.hpp" #include "shared_test_classes/base/ov_subgraph.hpp" #include 
"shared_test_classes/base/utils/compare_results.hpp" +#include namespace { using ov::test::InputShape; @@ -245,6 +247,26 @@ class KVCacheTests: public ::testing::Test { } } + /** + * @brief Additional LLM KV cache trimming parameter + */ + struct kv_cache_trim_params { + /// \param Length of sequence that starts trimming + int32_t trigger_len = 0; + /// \param Length of sequence after trimmed + int32_t trim_seq = 0; + }; + + /** + * @brief Additional LLM KV cache reordering parameter + */ + struct kv_cache_reorder_params { + /// \param Source indices for reorder + std::vector src_idx; + /// \param Destination indices for reorder + std::vector dst_idx; + }; + void test_smoke_multipleIterations_stateful(bool is_caching_test, bool fuse_cache_reorder, bool build_state_initializer, @@ -254,7 +276,9 @@ class KVCacheTests: public ::testing::Test { size_t num_iter = 10, size_t num_groups = 1, bool set_state_on_each_iter = false, - int32_t initial_batch = -1) { + int32_t initial_batch = -1, + const kv_cache_trim_params* trim_params = nullptr, + const kv_cache_reorder_params* reorder_params = nullptr) { #if defined(ANDROID) GTEST_SKIP(); #endif @@ -284,6 +308,22 @@ class KVCacheTests: public ::testing::Test { ov::element::Type element_type = model_element_type; const bool stateful = true; + if (trim_params) { + OPENVINO_ASSERT(trim_params->trim_seq > 0 && trim_params->trigger_len >= trim_params->trim_seq); + } + if (reorder_params) { + OPENVINO_ASSERT(trim_params); + OPENVINO_ASSERT(reorder_params->src_idx.size() == reorder_params->dst_idx.size()); + // make sure src/dst idx within range + const auto src_idx_fit = std::all_of(reorder_params->src_idx.begin(), reorder_params->src_idx.end(), [&](const auto& idx) { + return idx >= 0 && idx < trim_params->trigger_len; + }); + OPENVINO_ASSERT(src_idx_fit); + const auto dst_idx_fit = std::all_of(reorder_params->dst_idx.begin(), reorder_params->dst_idx.end(), [&](const auto& idx) { + return idx >= 0 && idx < trim_params->trim_seq; + 
}); + OPENVINO_ASSERT(dst_idx_fit); + } auto model = ov::test::utils::make_llm_kv_cache_pattern(build_state_initializer ? ov::Dimension::dynamic() : batch, n_heads, @@ -293,7 +333,9 @@ class KVCacheTests: public ::testing::Test { stateful, fuse_cache_reorder, build_state_initializer && stateful, - num_groups); + num_groups, + trim_params, + reorder_params); auto ref_model = ov::test::utils::make_llm_kv_cache_pattern(build_state_initializer ? ov::Dimension::dynamic() : batch, n_heads, n_features, @@ -302,28 +344,112 @@ class KVCacheTests: public ::testing::Test { !stateful, fuse_cache_reorder, build_state_initializer && !stateful, - num_groups); + num_groups, + trim_params, + reorder_params); if (is_caching_test) { core->compile_model(model, ov::test::utils::DEVICE_GPU, properties); } + ov::Shape unit_shape = {1}; + + struct kv_cache_trim_state { + const kv_cache_trim_params& trim; + ov::Tensor seq_len; + ov::Shape seq_len_shape; + kv_cache_trim_state(const kv_cache_trim_params& trim_params) : trim(trim_params), seq_len_shape{1} {} + virtual ~kv_cache_trim_state() {} + virtual std::optional update(const size_t past_seq_len) { + if (past_seq_len >= static_cast(trim.trigger_len)) { + seq_len.data()[0] = trim.trim_seq; + return trim.trim_seq; + } else { + OPENVINO_ASSERT(past_seq_len < std::numeric_limits::max()); + seq_len.data()[0] = static_cast(past_seq_len); + return std::nullopt; + } + } + }; + struct kv_cache_reorder_state : kv_cache_trim_state { + const kv_cache_reorder_params& reorder; + ov::Tensor src_idx; + ov::Tensor dst_idx; + ov::Tensor src_idx_data; + ov::Tensor dst_idx_data; + ov::Shape src_shape; + ov::Shape dst_shape; + kv_cache_reorder_state(const kv_cache_trim_params& trim_params, + const kv_cache_reorder_params& reorder_params, + size_t batch, + size_t n_heads, + size_t n_features) + : kv_cache_trim_state(trim_params), + reorder(reorder_params), + src_shape{reorder.src_idx.size()}, + dst_shape{batch, n_heads, reorder.dst_idx.size(), n_features} { + 
src_idx_data = ov::Tensor(ov::element::i32, src_shape, reorder.src_idx.data()); + dst_idx_data = ov::Tensor(ov::element::i32, dst_shape); + auto dst_ptr = dst_idx_data.data(); + for (size_t b = 0; b < batch; ++b) { + for (size_t h = 0; h < n_heads; ++h) { + for (const auto slice : reorder.dst_idx) { + dst_ptr = std::fill_n(dst_ptr, n_features, slice); + } + } + } + } + ~kv_cache_reorder_state() override {} + std::optional update(const size_t past_seq_len) override { + const auto new_len = kv_cache_trim_state::update(past_seq_len); + if (new_len.has_value()) { + src_shape = {reorder.src_idx.size()}; + dst_shape[2] = reorder.dst_idx.size(); + src_idx.set_shape(src_shape); + dst_idx.set_shape(dst_shape); + src_idx_data.copy_to(src_idx); + dst_idx_data.copy_to(dst_idx); + } else { + src_shape = {0}; + dst_shape[2] = 0; + src_idx.set_shape(src_shape); + dst_idx.set_shape(dst_shape); + } + return new_len; + } + }; + std::unique_ptr extra_state; + if (reorder_params) { + extra_state = std::make_unique(*trim_params, *reorder_params, batch, n_heads, n_features); + } else if (trim_params) { + extra_state = std::make_unique(*trim_params); + } + auto compiled_model = core->compile_model(model, ov::test::utils::DEVICE_GPU, properties); - auto input0 = model->get_parameters().at(0); - auto input1 = model->get_parameters().at(1); - auto input2 = fuse_cache_reorder ? model->get_parameters().at(2) : nullptr; + size_t param_idx = 0; + auto input0 = model->get_parameters().at(param_idx++); + auto input1 = model->get_parameters().at(param_idx++); + auto input_beam_idx = fuse_cache_reorder ? model->get_parameters().at(param_idx++) : nullptr; + auto input_src_idx = reorder_params ? model->get_parameters().at(param_idx++) : nullptr; + auto input_dst_idx = reorder_params ? model->get_parameters().at(param_idx++) : nullptr; + auto input_trim = trim_params ? 
model->get_parameters().at(param_idx++) : nullptr; auto output0 = model->get_results().at(0); auto beam_idx_shape = ov::Shape{batch}; - auto get_ref_results = [&ref_model, fuse_cache_reorder](const ov::Tensor& kv_cache, - const ov::Tensor& new_token_data, - const ov::Tensor& matmul_data, - const ov::Tensor& beam_idx_data, - const ov::Shape& beam_idx_shape) { - auto input0 = ref_model->get_parameters().at(0); - auto input1 = ref_model->get_parameters().at(1); - auto input2 = ref_model->get_parameters().at(2); - auto input3 = fuse_cache_reorder ? ref_model->get_parameters().at(3) : nullptr; + auto get_ref_results = [&ref_model, fuse_cache_reorder, trim_params, reorder_params](const ov::Tensor& kv_cache, + const ov::Tensor& new_token_data, + const ov::Tensor& matmul_data, + const ov::Tensor& beam_idx_data, + const ov::Shape& beam_idx_shape, + const kv_cache_trim_state* extra_state) { + size_t param_idx = 0; + auto input0 = ref_model->get_parameters().at(param_idx++); + auto input1 = ref_model->get_parameters().at(param_idx++); + auto input2 = ref_model->get_parameters().at(param_idx++); + auto input_beam_idx = fuse_cache_reorder ? ref_model->get_parameters().at(param_idx++) : nullptr; + auto input_src_idx = reorder_params ? ref_model->get_parameters().at(param_idx++) : nullptr; + auto input_dst_idx = reorder_params ? ref_model->get_parameters().at(param_idx++) : nullptr; + auto input_trim = trim_params ? 
ref_model->get_parameters().at(param_idx++) : nullptr; std::map, ov::PartialShape> input_shapes = { {input0, kv_cache.get_shape()}, {input1, new_token_data.get_shape()}, @@ -335,8 +461,20 @@ class KVCacheTests: public ::testing::Test { {input2, matmul_data} }; if (fuse_cache_reorder) { - input_shapes[input3] = beam_idx_shape; - inputs.emplace(input3, beam_idx_data); + input_shapes[input_beam_idx] = beam_idx_shape; + inputs.emplace(input_beam_idx, beam_idx_data); + } + if (reorder_params) { + const auto& extra = static_cast(*extra_state); + input_shapes[input_src_idx] = extra.src_shape; + inputs.emplace(input_src_idx, extra.src_idx); + input_shapes[input_dst_idx] = extra.dst_shape; + inputs.emplace(input_dst_idx, extra.dst_idx); + } + if (trim_params) { + const auto& extra = *extra_state; + input_shapes[input_trim] = extra.seq_len_shape; + inputs.emplace(input_trim, extra.seq_len); } ref_model->reshape(input_shapes); return ov::test::utils::infer_on_template(ref_model, inputs); @@ -364,6 +502,18 @@ class KVCacheTests: public ::testing::Test { auto new_token_input = infer_request.get_tensor(input0); auto matmul_input = infer_request.get_tensor(input1); + if (extra_state) { + extra_state->seq_len = infer_request.get_tensor(input_trim); + infer_request.set_tensor(input_trim, extra_state->seq_len); + } + if (reorder_params) { + auto& extra = static_cast(*extra_state); + extra.src_idx = infer_request.get_tensor(input_src_idx); + extra.dst_idx = infer_request.get_tensor(input_dst_idx); + infer_request.set_tensor(input_src_idx, extra.src_idx); + infer_request.set_tensor(input_dst_idx, extra.dst_idx); + } + infer_request.set_tensor(input0, new_token_input); infer_request.set_tensor(input1, matmul_input); @@ -393,12 +543,16 @@ class KVCacheTests: public ::testing::Test { } if (fuse_cache_reorder) { - infer_request.set_tensor(input2, init_beam_idx_data_0); + infer_request.set_tensor(input_beam_idx, init_beam_idx_data_0); + } + + if (extra_state) { + cache_size = 
extra_state->update(cache_size).value_or(cache_size); } ref_kv_cache = ov::Tensor(element_type, kv_cache_size_initial); - auto ref_results = get_ref_results(ref_kv_cache, new_token_data, matmul_data, init_beam_idx_data_0, init_beam_idx_shape); + auto ref_results = get_ref_results(ref_kv_cache, new_token_data, matmul_data, init_beam_idx_data_0, init_beam_idx_shape, extra_state.get()); ref_kv_cache = ref_results[0]; infer_request.infer(); @@ -426,16 +580,20 @@ class KVCacheTests: public ::testing::Test { const size_t input_tokens = 1; const ov::Shape new_token_size = {batch, input_tokens, n_heads / num_groups, n_features}; - size_t context_length = cache_size + input_tokens; - for (size_t i = 0; i < num_iter; i++, context_length += input_tokens) { + for (size_t i = 0; i < num_iter; i++) { + if (extra_state) { + cache_size = extra_state->update(cache_size).value_or(cache_size); + } + size_t context_length = cache_size + input_tokens; ov::Shape matmul_in_size_loop = {batch, n_heads, input_tokens, context_length}; auto new_token_data = ov::test::utils::create_and_fill_tensor(element_type, new_token_size); auto matmul_data = ov::test::utils::create_and_fill_tensor(element_type, matmul_in_size_loop); size_t beam_idx_array_idx = i == 0 ? 
2 : i % 2; if (fuse_cache_reorder) { - infer_request.set_tensor(input2, beam_idx_data_array[beam_idx_array_idx]); + infer_request.set_tensor(input_beam_idx, beam_idx_data_array[beam_idx_array_idx]); } - auto ref_results = get_ref_results(ref_kv_cache, new_token_data, matmul_data, beam_idx_data_array[beam_idx_array_idx], beam_idx_shape); + auto ref_results = + get_ref_results(ref_kv_cache, new_token_data, matmul_data, beam_idx_data_array[beam_idx_array_idx], beam_idx_shape, extra_state.get()); ref_kv_cache = ref_results[0]; new_token_input.set_shape(new_token_data.get_shape()); @@ -454,6 +612,8 @@ class KVCacheTests: public ::testing::Test { auto state_1 = infer_request.query_state()[0].get_state(); compare_tensors({ ref_kv_cache }, {state_1}); } + + cache_size = context_length; } auto state = infer_request.query_state()[0].get_state(); @@ -524,6 +684,30 @@ TEST_F(KVCacheTests, smoke_multipleIterations_stateful_with_set_state) { this->test_smoke_multipleIterations_stateful(false, true, true, 1, 2, ov::element::f16, 5, 1, true); } +TEST_F(KVCacheTests, smoke_multipleIterations_stateful_trim) { + kv_cache_trim_params trim; + trim.trigger_len = 17; + trim.trim_seq = 14; + this->test_smoke_multipleIterations_stateful(false, false, true, 1, 2, ov::element::f16, 5, 1, true, 1, &trim); +} + +TEST_F(KVCacheTests, smoke_multipleIterations_stateful_beam_trim) { + kv_cache_trim_params trim; + trim.trigger_len = 200; + trim.trim_seq = 200; + this->test_smoke_multipleIterations_stateful(false, true, true, 1, 2, ov::element::f16, 5, 1, true, 1, &trim); +} + +TEST_F(KVCacheTests, smoke_multipleIterations_stateful_trim_reorder) { + kv_cache_trim_params trim; + kv_cache_reorder_params reorder; + trim.trigger_len = 18; + trim.trim_seq = 14; + reorder.src_idx = {12, 13, 14}; + reorder.dst_idx = {10, 11, 12}; + this->test_smoke_multipleIterations_stateful(false, false, true, 1, 2, ov::element::f16, 5, 1, true, 1, &trim, &reorder); +} + class KVCacheIssueTests: public ::testing::Test { 
public: void test_smoke_conflicted_memory_for_two_inf_req() { diff --git a/src/plugins/intel_gpu/tests/unit/dynamic_execution/stateful_model.cpp b/src/plugins/intel_gpu/tests/unit/dynamic_execution/stateful_model.cpp index 8ece830bb4dfdf..c3c2325fd67224 100644 --- a/src/plugins/intel_gpu/tests/unit/dynamic_execution/stateful_model.cpp +++ b/src/plugins/intel_gpu/tests/unit/dynamic_execution/stateful_model.cpp @@ -207,7 +207,7 @@ TEST(stateful_model, check_dynamic_pad_for_kv_cache) { ov::Shape{}, // output shape 0, // batch_dim true), // support_neg_ind - kv_cache("concat", {input_info("gather"), input_info("present")}, info, 0, 0, false), + kv_cache("concat", {input_info("gather"), input_info("present")}, info, 0, 0, false, false, false), reorder("reorder", input_info("concat"), format::bfyx, data_types::f32)); /*output padding*/ ExecutionConfig config = get_test_default_config(engine); @@ -264,7 +264,7 @@ TEST(stateful_model, kv_cache_release) { topology topology(input_layout("past", input_lay), input_layout("present", input_lay), read_value("kv_cache", {input_info("past")}, info.variable_id, {input_kv_lay}), - kv_cache("concat", {input_info("kv_cache"), input_info("present")}, info, 0, 0, false), + kv_cache("concat", {input_info("kv_cache"), input_info("present")}, info, 0, 0, false, false, false), reorder("reorder", input_info("concat"), format::bfyx, data_types::f32)); /*output padding*/ ExecutionConfig config = get_test_default_config(engine); diff --git a/src/plugins/intel_gpu/tests/unit/shape_infer/kv_cache_si_test.cpp b/src/plugins/intel_gpu/tests/unit/shape_infer/kv_cache_si_test.cpp index 79803a4d01d7ea..a3855a4c3266e8 100644 --- a/src/plugins/intel_gpu/tests/unit/shape_infer/kv_cache_si_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/shape_infer/kv_cache_si_test.cpp @@ -47,7 +47,7 @@ TEST_P(kv_cache_test, shape_infer) { ov::op::util::VariableInfo info{p.input_layouts[0].get_partial_shape(), p.input_layouts[0].data_type, "v0"}; - auto kv_cache_prim = 
std::make_shared("output", input_prims_ids, info, p.concat_axis, p.gather_axis, p.indirect); + auto kv_cache_prim = std::make_shared("output", input_prims_ids, info, p.concat_axis, p.gather_axis, p.indirect, false, false); auto& kv_cache_node = prog.get_or_create(kv_cache_prim); for (size_t i = 0; i < p.input_layouts.size(); i++) { auto& input_node = prog.get_or_create(input_prims[i]); diff --git a/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_compression.cpp b/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_compression.cpp index 3da75a00596c4e..2c026103372e2d 100644 --- a/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_compression.cpp +++ b/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_compression.cpp @@ -117,6 +117,7 @@ TEST_F(TransformationTestsF, KVCacheCompression) { auto key_cache_inputs = ov::OutputVector{ key_past_compressed->output(0), key_current, beam_idx, key_past_compressed->output(1) }; auto key_cache = std::make_shared(key_cache_inputs, key_variable, + false, concat_axis, gather_axis, dq_config); @@ -129,6 +130,7 @@ TEST_F(TransformationTestsF, KVCacheCompression) { auto value_cache_inputs = ov::OutputVector{ value_past_compressed->output(0), value_current, beam_idx, value_past_compressed->output(1) }; auto value_cache = std::make_shared(value_cache_inputs, value_variable, + false, concat_axis, gather_axis, dq_config); @@ -178,6 +180,160 @@ TEST_F(TransformationTestsF, KVCacheCompression) { } } +TEST_F(TransformationTestsF, KVCacheCompressionWithPastSeqLen) { + bool causal = false; + bool with_mask = true; + bool with_scale = true; + size_t concat_axis = 2; + size_t gather_axis = 0; + ov::element::Type_t element_type = ov::element::f16; + std::vector qkv_order = {0, 1, 2, 3}; + std::shared_ptr mask = nullptr; + std::shared_ptr scale = nullptr; + ov::PartialShape input_shape = ov::PartialShape{1, 32, -1, 80}; + + { + auto query = std::make_shared(element_type, input_shape); + auto beam_idx = 
std::make_shared(ov::element::i32, ov::PartialShape{1}); + auto past_seq_len = std::make_shared(ov::element::i32, ov::PartialShape{1}); + + auto key_variable = std::make_shared(ov::op::util::VariableInfo{{1, 32, -1, 80}, ov::element::f16, "v0"}); + auto key_current = std::make_shared(ov::element::f16, input_shape); + auto key_past = std::make_shared(key_variable); + auto key_cache = std::make_shared(key_past, key_current, beam_idx, past_seq_len, key_variable, concat_axis, gather_axis); + + auto value_variable = std::make_shared(ov::op::util::VariableInfo{{1, 32, -1, 80}, ov::element::f16, "v1"}); + auto value_current = std::make_shared(ov::element::f16, input_shape); + auto value_past = std::make_shared(value_variable); + auto value_cache = + std::make_shared(value_past, value_current, beam_idx, past_seq_len, value_variable, concat_axis, gather_axis); + + ov::ParameterVector params{ past_seq_len, beam_idx, query, key_current, value_current }; + + if (with_mask) { + auto attn_mask = std::make_shared(element_type, ov::PartialShape::dynamic(4)); + mask = attn_mask; + params.push_back(attn_mask); + } + + if (with_mask && with_scale) { + auto scale_input = std::make_shared(element_type, ov::PartialShape{1}); + scale = scale_input; + params.push_back(scale_input); + } + + ov::OutputVector sdpa_inputs = { query, key_cache->output(0), value_cache->output(0) }; + + if (mask) { + sdpa_inputs.push_back(mask); + } + + if (scale) { + sdpa_inputs.push_back(scale); + } + + std::shared_ptr sdpa = nullptr; + sdpa = std::make_shared(sdpa_inputs, + key_cache->output(1), + causal, + gather_axis, + qkv_order, + qkv_order, + qkv_order, + ov::intel_gpu::op::SDPA::default_order(4)); + + auto result = std::make_shared(sdpa); + + ov::ResultVector results{ result }; + + model = std::make_shared(results, params); + manager.register_pass(ov::element::i8, false); + } + { + ov::op::internal::DynamicQuantize::Attributes dq_config; + dq_config.quantization_type = 
ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric; + dq_config.quantization_dt = ov::element::i8; + dq_config.scale_dt = ov::element::f16; + dq_config.zp_dt = ov::element::f16; + dq_config.group_sizes = { 1, 1, 1, UINT64_MAX }; + dq_config.scales_zp_output_order = { 0, 1, 2, 3 }; + dq_config.output_storage_type = ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP; + + auto query = std::make_shared(element_type, input_shape); + auto beam_idx = std::make_shared(ov::element::i32, ov::PartialShape{1}); + auto past_seq_len = std::make_shared(ov::element::i32, ov::PartialShape{1}); + + auto key_variable = std::make_shared(ov::op::util::VariableInfo{{1, 32, -1, 80}, ov::element::f16, "v0"}); + auto key_current = std::make_shared(ov::element::f16, input_shape); + auto key_past_variable_infos = { ov::op::util::VariableInfo{{1, 32, -1, 80}, ov::element::i8, "v0"}, + ov::op::util::VariableInfo{{1, 32, -1, 2}, ov::element::f16, "v0"} }; + auto key_past_compressed = std::make_shared(key_variable, key_past_variable_infos); + auto key_cache_inputs = ov::OutputVector{ key_past_compressed->output(0), key_current, beam_idx, past_seq_len, key_past_compressed->output(1) }; + auto key_cache = std::make_shared(key_cache_inputs, + key_variable, + true, + concat_axis, + gather_axis, + dq_config); + + auto value_variable = std::make_shared(ov::op::util::VariableInfo{{1, 32, -1, 80}, ov::element::f16, "v1"}); + auto value_current = std::make_shared(ov::element::f16, input_shape); + auto value_past_variable_infos = { ov::op::util::VariableInfo{{1, 32, -1, 80}, ov::element::i8, "v1"}, + ov::op::util::VariableInfo{{1, 32, -1, 2}, ov::element::f16, "v1"} }; + auto value_past_compressed = std::make_shared(value_variable, value_past_variable_infos); + auto value_cache_inputs = ov::OutputVector{value_past_compressed->output(0), value_current, beam_idx, past_seq_len, value_past_compressed->output(1)}; + auto value_cache = std::make_shared(value_cache_inputs, + 
value_variable, + true, + concat_axis, + gather_axis, + dq_config); + + ov::ParameterVector params{ past_seq_len, beam_idx, query, key_current, value_current }; + + if (with_mask) { + auto attn_input = std::make_shared(element_type, ov::PartialShape::dynamic(4)); + mask = attn_input; + params.push_back(attn_input); + } + + if (with_mask && with_scale) { + auto scale_input = std::make_shared(element_type, ov::PartialShape{1}); + scale = scale_input; + params.push_back(scale_input); + } + + ov::OutputVector sdpa_inputs = { query, key_cache->output(0), value_cache->output(0) }; + if (mask) { + sdpa_inputs.push_back(mask); + } + + if (scale) { + sdpa_inputs.push_back(scale); + } + + sdpa_inputs.push_back(key_cache->output(2)); + sdpa_inputs.push_back(value_cache->output(2)); + + std::shared_ptr sdpa = nullptr; + sdpa = std::make_shared(sdpa_inputs, + key_cache->output(1), + causal, + gather_axis, + qkv_order, + qkv_order, + qkv_order, + ov::intel_gpu::op::SDPA::default_order(4), + dq_config); + + auto result = std::make_shared(sdpa); + + ov::ResultVector results{ result }; + + model_ref = std::make_shared(results, params); + } +} + TEST_F(TransformationTestsF, KVCacheCompressionWithInitializers) { bool causal = false; bool with_mask = true; @@ -273,6 +429,7 @@ TEST_F(TransformationTestsF, KVCacheCompressionWithInitializers) { auto key_cache_inputs = ov::OutputVector{ key_past_compressed->output(0), key_current, beam_idx, key_past_compressed->output(1) }; auto key_cache = std::make_shared(key_cache_inputs, key_variable, + false, concat_axis, gather_axis, dq_config); @@ -291,6 +448,7 @@ TEST_F(TransformationTestsF, KVCacheCompressionWithInitializers) { auto value_cache_inputs = ov::OutputVector{ value_past_compressed->output(0), value_current, beam_idx, value_past_compressed->output(1) }; auto value_cache = std::make_shared(value_cache_inputs, value_variable, + false, concat_axis, gather_axis, dq_config); @@ -414,6 +572,7 @@ TEST_F(TransformationTestsF, 
NewTestKVCacheCompression) { auto key_cache_inputs = ov::OutputVector{key_past_compressed->output(0), key_current_vs->output(2), beam_idx, key_past_compressed->output(1)}; auto key_cache = std::make_shared(key_cache_inputs, key_variable, + false, concat_axis, gather_axis, dq_config); @@ -435,6 +594,7 @@ TEST_F(TransformationTestsF, NewTestKVCacheCompression) { auto value_cache_inputs = ov::OutputVector{value_past_compressed->output(0), value_current_vs->output(2), beam_idx, value_past_compressed->output(1)}; auto value_cache = std::make_shared(value_cache_inputs, value_variable, + false, concat_axis, gather_axis, dq_config); diff --git a/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_fusion_test.cpp b/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_fusion_test.cpp index c0c73bbf6d5eb7..00c3924e58e7f0 100644 --- a/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_fusion_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_fusion_test.cpp @@ -22,6 +22,7 @@ #include "openvino/op/parameter.hpp" #include "openvino/op/result.hpp" #include "openvino/op/sink.hpp" +#include "openvino/op/slice.hpp" #include "openvino/op/variadic_split.hpp" #include "intel_gpu/op/kv_cache.hpp" @@ -116,3 +117,32 @@ TEST_F(TransformationTestsF, KVCacheFusionTest3) { model_ref = std::make_shared(ov::ResultVector{result}, ov::ParameterVector{parameter, beam_idx}); } } + +TEST_F(TransformationTestsF, KVCacheFusionTest4) { + { + auto variable = std::make_shared(ov::op::util::VariableInfo{{1, 32, -1, 80}, ov::element::f16, "v0"}); + auto past = std::make_shared(variable); + auto parameter = std::make_shared(ov::element::f16, ov::PartialShape{1, 32, -1, 80}); + auto past_seq_len = std::make_shared(ov::element::i64, ov::PartialShape{1}); + auto start = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); + auto step = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); + auto slice_axis = 
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); + auto slice = std::make_shared(past, start, past_seq_len, step, slice_axis); + auto concat = std::make_shared(ov::OutputVector{slice, parameter}, 2); + auto present = std::make_shared(concat, variable); + auto result = std::make_shared(concat); + + model = std::make_shared(ov::ResultVector{result}, ov::SinkVector{present}, ov::ParameterVector{parameter, past_seq_len}); + manager.register_pass(); + } + { + auto variable = std::make_shared(ov::op::util::VariableInfo{{1, 32, -1, 80}, ov::element::f16, "v0"}); + auto parameter = std::make_shared(ov::element::f16, ov::PartialShape{1, 32, -1, 80}); + auto past_seq_len = std::make_shared(ov::element::i64, ov::PartialShape{1}); + auto past = std::make_shared(variable); + auto kv_cache = std::make_shared(past, parameter, past_seq_len, variable, 2, ov::element::f16); + auto result = std::make_shared(kv_cache); + + model_ref = std::make_shared(ov::ResultVector{result}, ov::ParameterVector{parameter, past_seq_len}); + } +} diff --git a/src/tests/test_utils/common_test_utils/include/common_test_utils/subgraph_builders/llm_builders.hpp b/src/tests/test_utils/common_test_utils/include/common_test_utils/subgraph_builders/llm_builders.hpp index 9077f1db1cd3dd..04ab1234f306b9 100644 --- a/src/tests/test_utils/common_test_utils/include/common_test_utils/subgraph_builders/llm_builders.hpp +++ b/src/tests/test_utils/common_test_utils/include/common_test_utils/subgraph_builders/llm_builders.hpp @@ -127,6 +127,7 @@ ov::ParameterVector form_sdpa_params(ov::Dimension batch, * @param fuse_cache_reorder Whether to fuse cache reorder * @param build_state_initializer Whether to build state initializer * @param num_groups Number of groups for GQA + * @param kv_cache_reorder Whether to do additional LLM KV cache reordering * @return Shared pointer to the created model */ std::shared_ptr make_llm_kv_cache_pattern(ov::Dimension batch = ov::Dimension::dynamic(), @@ -137,7 +138,9 @@ 
std::shared_ptr make_llm_kv_cache_pattern(ov::Dimension batch = ov::D bool stateful = false, bool fuse_cache_reorder = false, bool build_state_initializer = false, - size_t num_groups = 1); + size_t num_groups = 1, + bool kv_cache_trim = false, + bool kv_cache_reorder = false); /** * @brief Creates an LLM KV cache pattern with Scaled Dot Product Attention (SDPA) diff --git a/src/tests/test_utils/common_test_utils/src/subgraph_builders/llm_builders.cpp b/src/tests/test_utils/common_test_utils/src/subgraph_builders/llm_builders.cpp index 52ab91e35d2b9c..11b780ac660cec 100644 --- a/src/tests/test_utils/common_test_utils/src/subgraph_builders/llm_builders.cpp +++ b/src/tests/test_utils/common_test_utils/src/subgraph_builders/llm_builders.cpp @@ -28,8 +28,10 @@ #include "openvino/op/reshape.hpp" #include "openvino/op/result.hpp" #include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/op/scatter_elements_update.hpp" #include "openvino/op/select.hpp" #include "openvino/op/shape_of.hpp" +#include "openvino/op/slice.hpp" #include "openvino/op/transpose.hpp" #include "openvino/op/unsqueeze.hpp" #include "openvino/pass/make_stateful.hpp" @@ -232,7 +234,13 @@ std::shared_ptr make_llm_kv_cache_pattern(ov::Dimension batch, bool stateful, bool fuse_cache_reorder, bool build_state_initializer, - size_t num_groups) { + size_t num_groups, + bool kv_cache_trim, + bool kv_cache_reorder) { + if (kv_cache_reorder) { + OPENVINO_ASSERT(kv_cache_trim); + } + ov::PartialShape kv_cache_size = {batch, n_heads / num_groups, -1, n_features}; ov::PartialShape new_token_size = {batch, -1, n_heads / num_groups, n_features}; ov::PartialShape matmul_in_size = {batch, n_heads, -1, -1}; @@ -252,6 +260,32 @@ std::shared_ptr make_llm_kv_cache_pattern(ov::Dimension batch, params.push_back(in_beam_idx); concat_input = make_kv_rearrange(in_kv_prev, in_beam_idx); } + auto context_axis_const = ov::op::v0::Constant::create(ov::element::i32, {1}, {2}); + if (kv_cache_reorder) { + 
OPENVINO_ASSERT(n_heads.is_static() && n_features.is_static()); + ov::PartialShape src_shape = {-1}; + auto in_src_idx = std::make_shared(ov::element::i32, src_shape); + in_src_idx->set_friendly_name("src_idx"); + params.push_back(in_src_idx); + // dst_idx has to be param, not const! + ov::PartialShape dst_shape = {batch, n_heads, -1, n_features}; + auto in_dst_idx = std::make_shared(ov::element::i32, dst_shape); + in_dst_idx->set_friendly_name("dst_idx"); + params.push_back(in_dst_idx); + auto updates = std::make_shared(concat_input, in_src_idx, context_axis_const); + concat_input = + std::make_shared(concat_input, in_dst_idx, updates, context_axis_const); + } + if (kv_cache_trim) { + ov::PartialShape unit_shape = {1}; + auto seq_len = std::make_shared(ov::element::i32, unit_shape); + seq_len->set_friendly_name("seq_len"); + params.push_back(seq_len); + auto zero_const = ov::op::v0::Constant::create(ov::element::i32, {1}, {0}); + auto one_const = ov::op::v0::Constant::create(ov::element::i32, {1}, {1}); + concat_input = + std::make_shared(concat_input, zero_const, seq_len, one_const, context_axis_const); + } auto transpose_const = ov::op::v0::Constant::create(ov::element::i32, {new_token_size.size()}, {0, 2, 1, 3}); auto transpose = std::make_shared(in_new_token, transpose_const);