Commit 7d38f6f

Enable multi head size support for KV cache
In continuous batching, the head size for key and value can be different. Add support for this in SDPA.
1 parent d119656 commit 7d38f6f

20 files changed: +466, -300 lines
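
Background on the two new fields (a reviewer's note, not part of the diff): in scaled dot-product attention the query and key must share a per-head dimension for the Q.K^T dot product, while the attention output inherits the value's per-head dimension, so k_head_size and v_head_size are independent and can legitimately differ. A minimal, self-contained sketch of a single-query, single-head attention step that shows where each size appears; all names and numbers are illustrative and not taken from the plugin:

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    const std::size_t ctx = 4;           // cached tokens attended over
    const std::size_t k_head_size = 3;   // per-head dim shared by Q and K
    const std::size_t v_head_size = 2;   // per-head dim of V (independent of K)

    std::vector<float> q(k_head_size, 1.0f);
    std::vector<std::vector<float>> K(ctx, std::vector<float>(k_head_size, 0.5f));
    std::vector<std::vector<float>> V(ctx, std::vector<float>(v_head_size, 2.0f));

    // scores[t] = exp(q . K[t] / sqrt(k_head_size)); only k_head_size matters here
    std::vector<float> scores(ctx, 0.0f);
    float denom = 0.0f;
    for (std::size_t t = 0; t < ctx; ++t) {
        float dot = 0.0f;
        for (std::size_t d = 0; d < k_head_size; ++d)
            dot += q[d] * K[t][d];
        scores[t] = std::exp(dot / std::sqrt(static_cast<float>(k_head_size)));
        denom += scores[t];
    }

    // output = sum_t softmax(scores)[t] * V[t]; the result has v_head_size elements
    std::vector<float> out(v_head_size, 0.0f);
    for (std::size_t t = 0; t < ctx; ++t)
        for (std::size_t d = 0; d < v_head_size; ++d)
            out[d] += (scores[t] / denom) * V[t][d];

    std::cout << "output size = " << out.size() << " (== v_head_size)\n";
    return 0;
}

This is also why the output layout of the operation follows the value input rather than the query, as the calc_output_layouts change further down reflects.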

src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp

Lines changed: 8 additions & 4 deletions
@@ -35,7 +35,8 @@ struct paged_attention : public primitive_base<paged_attention> {

         auto rhs_casted = downcast<const paged_attention>(rhs);

-        return head_size == rhs_casted.head_size &&
+        return k_head_size == rhs_casted.k_head_size &&
+               v_head_size == rhs_casted.v_head_size &&
                heads_num == rhs_casted.heads_num &&
                kv_heads_num == rhs_casted.kv_heads_num &&
                sliding_window == rhs_casted.sliding_window &&
@@ -46,7 +47,8 @@ struct paged_attention : public primitive_base<paged_attention> {

     void save(BinaryOutputBuffer& ob) const override {
         primitive_base<paged_attention>::save(ob);
-        ob << head_size;
+        ob << k_head_size;
+        ob << v_head_size;
         ob << heads_num;
         ob << kv_heads_num;
         ob << has_alibi;
@@ -63,7 +65,8 @@ struct paged_attention : public primitive_base<paged_attention> {

     void load(BinaryInputBuffer& ib) override {
         primitive_base<paged_attention>::load(ib);
-        ib >> head_size;
+        ib >> k_head_size;
+        ib >> v_head_size;
         ib >> heads_num;
         ib >> kv_heads_num;
         ib >> has_alibi;
@@ -82,7 +85,8 @@ struct paged_attention : public primitive_base<paged_attention> {
     }

     std::optional<float> scale_val{};
-    size_t head_size = 0;
+    size_t k_head_size = 0;
+    size_t v_head_size = 0;
     size_t heads_num = 0;
     size_t kv_heads_num = 0;
     size_t sliding_window = 0;
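
A side note on the save()/load() pair above: whatever order the fields are streamed out in must be mirrored when they are read back, which the two hunks keep in sync (k_head_size before v_head_size in both). A small, self-contained sketch of that pattern; the Config struct and std::stringstream are stand-ins for illustration, not the plugin's BinaryOutputBuffer/BinaryInputBuffer API:

#include <cassert>
#include <cstddef>
#include <sstream>

// Hypothetical stand-in for the serialized primitive fields.
struct Config {
    std::size_t k_head_size = 0;
    std::size_t v_head_size = 0;
    std::size_t heads_num = 0;

    void save(std::ostream& ob) const {
        ob << k_head_size << ' ' << v_head_size << ' ' << heads_num;
    }
    void load(std::istream& ib) {
        ib >> k_head_size >> v_head_size >> heads_num;  // same order as save()
    }
};

int main() {
    Config a;
    a.k_head_size = 128;
    a.v_head_size = 64;
    a.heads_num = 8;

    std::stringstream buf;
    a.save(buf);

    Config b;
    b.load(buf);
    assert(b.k_head_size == 128 && b.v_head_size == 64 && b.heads_num == 8);
    return 0;
}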

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

Lines changed: 2 additions & 1 deletion
@@ -640,7 +640,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         kernel_selector::sdpa_configuration config;

         const auto desc = impl_param.typed_desc<paged_attention>();
-        config.head_size = desc->head_size;
+        config.k_head_size = desc->k_head_size;
+        config.v_head_size = desc->v_head_size;
         config.heads_num = desc->heads_num;
         config.kv_heads_num = desc->kv_heads_num;
         config.has_alibi_input = desc->has_alibi;

src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp

Lines changed: 1 addition & 1 deletion
@@ -266,7 +266,7 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
     }

     if (query_shape[query_shape.size() - 1].is_static())
-        config.head_size = query_shape[query_shape.size() - 1].get_length();
+        config.k_head_size = query_shape[query_shape.size() - 1].get_length();

     config.is_causal = desc->is_causal;
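
One observation on this hunk: only k_head_size is taken from the query, because the query's innermost dimension must match the key's for the dot product; the value head size would have to come from the value input's static shape (how the plugin fills it elsewhere is not shown in this diff). A short, self-contained sketch of that derivation under those assumptions; sdpa_config_sketch and the shape vectors are simplified stand-ins, not the kernel selector's types:

#include <cstddef>
#include <iostream>
#include <vector>

// Simplified stand-in for an SDPA configuration with split head sizes.
struct sdpa_config_sketch {
    std::size_t k_head_size = 0;
    std::size_t v_head_size = 0;
};

int main() {
    // Hypothetical static shapes: [batch, heads, seq_len, head_size].
    const std::vector<std::size_t> query_shape = {1, 8, 16, 128};
    const std::vector<std::size_t> value_shape = {1, 8, 16, 64};

    sdpa_config_sketch config;
    config.k_head_size = query_shape.back();  // Q and K share this dimension
    config.v_head_size = value_shape.back();  // output channels follow V

    std::cout << config.k_head_size << " / " << config.v_head_size << "\n";  // prints "128 / 64"
    return 0;
}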

src/plugins/intel_gpu/src/graph/paged_attention.cpp

Lines changed: 15 additions & 5 deletions
@@ -21,14 +21,21 @@ layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*no

 template<typename ShapeType>
 std::vector<layout> paged_attention_inst::calc_output_layouts(paged_attention_node const& /*node*/, kernel_impl_params const& impl_param) {
-    auto data_layout = impl_param.get_input_layout(0);
+    auto q_layout = impl_param.get_input_layout(0);
+    auto v_layout = impl_param.get_input_layout(2);
+    auto data_layout = q_layout;
+
+    if (v_layout.is_static()) {
+        ShapeType v_shape = v_layout.get_shape();
+        data_layout = layout{v_shape, data_layout.data_type, data_layout.format};
+    }
+
     data_layout.data_padding = padding();

     const auto& key_cache_ps = impl_param.get_input_layout(3).get_partial_shape();
     bool valid_block_size = key_cache_ps[3].is_dynamic() || key_cache_ps[3].get_length() == paged_attention::block_size;
     OPENVINO_ASSERT(valid_block_size, "[GPU] Incorrect block size for Paged Attention operation. "
                     "Expected ", paged_attention::block_size, ", but got ", key_cache_ps[3].get_length());
-
     std::vector<layout> output_layouts{ data_layout };

     const auto& desc = impl_param.typed_desc<paged_attention>();
@@ -67,7 +74,8 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) {

     json_composite paged_attention_info;
     paged_attention_info.add("paged_attention_block_size", desc->block_size);
-    paged_attention_info.add("head_size", desc->head_size);
+    paged_attention_info.add("k_head_size", desc->k_head_size);
+    paged_attention_info.add("v_head_size", desc->v_head_size);
     paged_attention_info.add("heads_num", desc->heads_num);
     paged_attention_info.add("kv_heads_num", desc->kv_heads_num);
     paged_attention_info.add("scale", desc->scale_val.value_or(1.0f));
@@ -85,7 +93,8 @@ paged_attention_inst::typed_primitive_inst(network& network, const paged_attenti
     : parent(network, node) {
     const auto desc = node.get_primitive();

-    const auto head_size = desc->head_size;
+    const auto k_head_size = desc->k_head_size;
+    const auto v_head_size = desc->v_head_size;
     const auto heads_num = desc->heads_num;
     const auto kv_heads_num = desc->kv_heads_num;
     const auto pa_block_size = desc->block_size;
@@ -97,6 +106,7 @@ paged_attention_inst::typed_primitive_inst(network& network, const paged_attenti
     }

     OPENVINO_ASSERT(heads_num % kv_heads_num == 0);
-    OPENVINO_ASSERT(head_size % pa_block_size == 0);
+    OPENVINO_ASSERT(k_head_size % pa_block_size == 0);
+    OPENVINO_ASSERT(v_head_size % pa_block_size == 0);
 }
 } // namespace cldnn
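
The new assertions above require both head sizes to be multiples of the paged-attention block size, and with two distinct head sizes the key and value halves of a cache block no longer occupy the same amount of memory. A rough, self-contained sizing sketch under assumed numbers and an assumed per-block layout (block layout, element type, and all values here are hypothetical, not read from the plugin):

#include <cstddef>
#include <iostream>

int main() {
    const std::size_t block_size   = 16;   // tokens per paged-attention block (assumed)
    const std::size_t kv_heads_num = 8;
    const std::size_t k_head_size  = 128;
    const std::size_t v_head_size  = 64;
    const std::size_t elem_bytes   = 2;    // e.g. f16

    // Assumed layout: one block stores block_size tokens for every KV head.
    const std::size_t key_block_bytes   = block_size * kv_heads_num * k_head_size * elem_bytes;
    const std::size_t value_block_bytes = block_size * kv_heads_num * v_head_size * elem_bytes;

    std::cout << "key block:   " << key_block_bytes   << " bytes\n"   // 32768
              << "value block: " << value_block_bytes << " bytes\n";  // 16384
    return 0;
}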

src/plugins/intel_gpu/src/graph/primitive_inst.cpp

Lines changed: 1 addition & 2 deletions
@@ -404,7 +404,7 @@ void primitive_inst::update_shape() {
         input_shape_changed = true;
     }

-    if (!_node->is_type<kv_cache>() && !input_shape_changed && _impl_params->get_output_layout().is_static())
+    if (!_node->is_type<kv_cache>() && !_node->is_type<strided_slice>() && !input_shape_changed && _impl_params->get_output_layout().is_static())
         return;

     std::vector<event::ptr> dependencies_events;
@@ -456,7 +456,6 @@ void primitive_inst::update_shape() {

     _impl_params->memory_deps = memory_deps;

-
     auto new_layouts = _node->type()->calc_output_layouts(*_node, *_impl_params);
     for (size_t idx = 0; idx != new_layouts.size(); ++idx) {
         auto& new_layout = new_layouts[idx];
