Skip to content

Commit 6b1cf4b

Browse files
Merge branch 'master' into gpu_onednn_grouped_gemm2
2 parents 7d867b0 + c694fbc commit 6b1cf4b

File tree

18 files changed

+1037
-155
lines changed

18 files changed

+1037
-155
lines changed

src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -55,33 +55,34 @@ ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& config,
5555
auto adaptive_rkv_evictable_sizes = pattern::any_input(pattern::has_static_rank());
5656
auto adaptive_rkv_diversity_block_set_indices = pattern::any_input(pattern::has_static_rank());
5757
auto adaptive_rkv_diversity_block_set_indices_begins = pattern::any_input(pattern::has_static_rank());
58-
59-
auto result =
60-
pattern::wrap_type<ov::op::PagedAttentionExtension>({Q,
61-
K,
62-
V,
63-
key_cache_0,
64-
value_cache_0,
65-
past_lens,
66-
subsequence_begins,
67-
block_indices,
68-
block_indices_begins,
69-
scale,
70-
sliding_window,
71-
alibi_slopes,
72-
max_context_len,
73-
score_aggregation_window,
74-
rotated_block_indices,
75-
rotation_deltas,
76-
rotation_trig_lut,
77-
xattention_threshold,
78-
xattention_block_size,
79-
xattention_stride,
80-
sinks,
81-
adaptive_rkv_start_size,
82-
adaptive_rkv_evictable_sizes,
83-
adaptive_rkv_diversity_block_set_indices,
84-
adaptive_rkv_diversity_block_set_indices_begins});
58+
auto token_type_ids = pattern::any_input(pattern::has_static_rank());
59+
60+
auto result = pattern::wrap_type<ov::op::PagedAttentionExtension>({Q,
61+
K,
62+
V,
63+
key_cache_0,
64+
value_cache_0,
65+
past_lens,
66+
subsequence_begins,
67+
block_indices,
68+
block_indices_begins,
69+
scale,
70+
sliding_window,
71+
alibi_slopes,
72+
max_context_len,
73+
score_aggregation_window,
74+
rotated_block_indices,
75+
rotation_deltas,
76+
rotation_trig_lut,
77+
xattention_threshold,
78+
xattention_block_size,
79+
xattention_stride,
80+
sinks,
81+
adaptive_rkv_start_size,
82+
adaptive_rkv_evictable_sizes,
83+
adaptive_rkv_diversity_block_set_indices,
84+
adaptive_rkv_diversity_block_set_indices_begins,
85+
token_type_ids});
8586
ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) {
8687
const auto pa_op = m.get_match_root();
8788
auto key_cache = ov::as_type_ptr<v0::Parameter>(pa_op->get_input_node_shared_ptr(3));

src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "openvino/op/bitwise_and.hpp"
1414
#include "openvino/op/broadcast.hpp"
1515
#include "openvino/op/concat.hpp"
16+
#include "openvino/op/convert.hpp"
1617
#include "openvino/op/divide.hpp"
1718
#include "openvino/op/gather.hpp"
1819
#include "openvino/op/greater.hpp"
@@ -197,6 +198,18 @@ static std::shared_ptr<ov::Node> handle_baichuan2_13b_alibi(
197198
return res_alibi_slopes;
198199
}
199200

201+
static std::shared_ptr<ov::Node> handle_gemma3_token_type_ids(
202+
const std::map<std::string, std::shared_ptr<v0::Parameter>>& optional_model_wide_params) {
203+
if (optional_model_wide_params.find("token_type_ids") != optional_model_wide_params.end()) {
204+
auto param = optional_model_wide_params.at("token_type_ids");
205+
if (param->get_element_type() != ov::element::i32) {
206+
return std::make_shared<v0::Convert>(param, ov::element::i32);
207+
}
208+
return param;
209+
}
210+
return v0::Constant::create(ov::element::i32, ov::Shape{0}, {});
211+
}
212+
200213
static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> phi3_sliding_window_pattern() {
201214
auto offset = wrap_type<v0::Constant>();
202215
auto t196 = wrap_type<v1::Add>({any_input(), offset});
@@ -216,7 +229,7 @@ static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> phi3_sli
216229
return {mask, offset};
217230
}
218231

219-
static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> gpt_oss_sliding_window_pattern() {
232+
static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> gptoss_gemma3_sliding_window_pattern() {
220233
auto q_idx = any_input();
221234
auto kv_idx = any_input();
222235

@@ -393,9 +406,9 @@ ov::pass::StateManagementPattern::StateManagementPattern(
393406
std::shared_ptr<ov::Node> phi3_mask, phi3_offset;
394407
std::tie(phi3_mask, phi3_offset) = phi3_sliding_window_pattern();
395408

396-
// gpt-oss case
397-
std::shared_ptr<ov::Node> gpt_oss_mask, gpt_oss_offset;
398-
std::tie(gpt_oss_mask, gpt_oss_offset) = gpt_oss_sliding_window_pattern();
409+
// gpt-oss and gemma3 cases
410+
std::shared_ptr<ov::Node> gptoss_gemma3_mask, gptoss_gemma3_offset;
411+
std::tie(gptoss_gemma3_mask, gptoss_gemma3_offset) = gptoss_gemma3_sliding_window_pattern();
399412

400413
// Scale's shape limitations according to SDPA specification
401414
auto scale_predicate = [=](const Output<Node>& output) -> bool {
@@ -414,7 +427,7 @@ ov::pass::StateManagementPattern::StateManagementPattern(
414427
general_alibi_mask,
415428
jais_alibi_mask,
416429
baichuan2_13b_alibi_mask,
417-
gpt_oss_mask,
430+
gptoss_gemma3_mask,
418431
any_input()});
419432

420433
auto sdpa_with_4_inputs = wrap_type<v13::ScaledDotProductAttention>({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa});
@@ -425,6 +438,11 @@ ov::pass::StateManagementPattern::StateManagementPattern(
425438

426439
auto sdpa_variants = std::make_shared<Or>(OutputVector{sdpa_with_4_inputs, sdpa_with_5_inputs, sdpa_with_6_inputs});
427440

441+
// Shared flag to track whether the model is Gemma3, set when any layer matches
442+
// the gptoss_gemma3 sliding window pattern. Combined with the token_type_ids check,
443+
// this uniquely identifies Gemma3 (gpt-oss shares the pattern but lacks token_type_ids).
444+
auto is_gptoss_gemma3 = std::make_shared<bool>(false);
445+
428446
ov::matcher_pass_callback callback = [=,
429447
&kv_parameters,
430448
&model_wide_params,
@@ -602,9 +620,10 @@ ov::pass::StateManagementPattern::StateManagementPattern(
602620
offset = std::make_shared<v0::Convert>(offset, element::i32);
603621
}
604622
sliding_window = std::make_shared<v1::Subtract>(v0::Constant::create(element::i32, Shape{}, {2}), offset);
605-
} else if (pattern_map.count(gpt_oss_offset)) {
606-
auto offset = pattern_map.at(gpt_oss_offset).get_node_shared_ptr();
607-
if (pattern_map.at(gpt_oss_offset).get_partial_shape().rank() != 0) {
623+
} else if (pattern_map.count(gptoss_gemma3_offset)) {
624+
*is_gptoss_gemma3 = true;
625+
auto offset = pattern_map.at(gptoss_gemma3_offset).get_node_shared_ptr();
626+
if (pattern_map.at(gptoss_gemma3_offset).get_partial_shape().rank() != 0) {
608627
offset = std::make_shared<v15::Squeeze>(offset);
609628
}
610629
if (offset->get_element_type() != element::i32) {
@@ -737,6 +756,13 @@ ov::pass::StateManagementPattern::StateManagementPattern(
737756
}
738757
OPENVINO_ASSERT(pa_arguments.size() == 25);
739758

759+
if (*is_gptoss_gemma3) {
760+
pa_arguments.insert(pa_arguments.begin() + 25, handle_gemma3_token_type_ids(optional_model_wide_params));
761+
} else {
762+
pa_arguments.insert(pa_arguments.begin() + 25, v0::Constant::create(element::i32, Shape{0}, {}));
763+
}
764+
OPENVINO_ASSERT(pa_arguments.size() == 26);
765+
740766
auto paged_attention = std::make_shared<ov::op::PagedAttentionExtension>(pa_arguments);
741767
paged_attention->get_rt_info()[NUM_K_HEADS] = num_k_heads;
742768
paged_attention->get_rt_info()[K_HEAD_SIZE] = k_head_size;

src/common/transformations/tests/common_optimizations/convert_pagedattn_inputs.cpp

Lines changed: 60 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -121,33 +121,35 @@ TEST_P(ConvertPagedAttnInputsTest, checkPrecisionAndShape) {
121121
std::make_shared<v0::Parameter>(ov::element::i32, PartialShape{DYN});
122122
auto adaptive_rkv_diversity_block_set_indices_begins =
123123
std::make_shared<v0::Parameter>(ov::element::i32, PartialShape{DYN});
124+
auto token_type_ids = std::make_shared<op::v0::Parameter>(ov::element::i32, ov::Shape{0});
124125

125-
auto pa = std::make_shared<op::PagedAttentionExtension>(
126-
OutputVector{Q,
127-
K,
128-
V,
129-
key_cache_0,
130-
value_cache_0,
131-
past_lens,
132-
subsequence_begins,
133-
block_indices,
134-
block_indices_begins,
135-
scale,
136-
sliding_window,
137-
alibi_slopes,
138-
max_context_len,
139-
score_aggregation_window,
140-
rotated_block_indices,
141-
rotation_deltas,
142-
rotation_trig_lut,
143-
xattention_threshold,
144-
xattention_block_size,
145-
xattention_stride,
146-
sinks,
147-
adaptive_rkv_start_size,
148-
adaptive_rkv_evictable_sizes,
149-
adaptive_rkv_diversity_block_set_indices,
150-
adaptive_rkv_diversity_block_set_indices_begins});
126+
auto pa =
127+
std::make_shared<op::PagedAttentionExtension>(OutputVector{Q,
128+
K,
129+
V,
130+
key_cache_0,
131+
value_cache_0,
132+
past_lens,
133+
subsequence_begins,
134+
block_indices,
135+
block_indices_begins,
136+
scale,
137+
sliding_window,
138+
alibi_slopes,
139+
max_context_len,
140+
score_aggregation_window,
141+
rotated_block_indices,
142+
rotation_deltas,
143+
rotation_trig_lut,
144+
xattention_threshold,
145+
xattention_block_size,
146+
xattention_stride,
147+
sinks,
148+
adaptive_rkv_start_size,
149+
adaptive_rkv_evictable_sizes,
150+
adaptive_rkv_diversity_block_set_indices,
151+
adaptive_rkv_diversity_block_set_indices_begins,
152+
token_type_ids});
151153
pa->get_rt_info()["num_k_heads"] = numKeyHeads;
152154
pa->get_rt_info()["k_head_size"] = keyHeadSize;
153155
pa->get_rt_info()["num_v_heads"] = numValueHeads;
@@ -174,7 +176,8 @@ TEST_P(ConvertPagedAttnInputsTest, checkPrecisionAndShape) {
174176
adaptive_rkv_start_size,
175177
adaptive_rkv_evictable_sizes,
176178
adaptive_rkv_diversity_block_set_indices,
177-
adaptive_rkv_diversity_block_set_indices_begins});
179+
adaptive_rkv_diversity_block_set_indices_begins,
180+
token_type_ids});
178181

179182
if (isIRKVCacheF16) {
180183
model->set_rt_info("f16", "runtime_options", ov::hint::kv_cache_precision.name());
@@ -254,33 +257,35 @@ TEST_P(ConvertPagedAttnInputsTest, checkPrecisionAndShape) {
254257
std::make_shared<v0::Parameter>(ov::element::i32, PartialShape{DYN});
255258
auto adaptive_rkv_diversity_block_set_indices_begins =
256259
std::make_shared<v0::Parameter>(ov::element::i32, PartialShape{DYN});
260+
auto token_type_ids = std::make_shared<v0::Parameter>(ov::element::i32, ov::Shape{0});
257261

258-
auto pa = std::make_shared<op::PagedAttentionExtension>(
259-
OutputVector{Q,
260-
K,
261-
V,
262-
key_cache_0,
263-
value_cache_0,
264-
past_lens,
265-
subsequence_begins,
266-
block_indices,
267-
block_indices_begins,
268-
scale,
269-
sliding_window,
270-
alibi_slopes,
271-
max_context_len,
272-
score_aggregation_window,
273-
rotated_block_indices,
274-
rotation_deltas,
275-
rotation_trig_lut,
276-
xattention_threshold,
277-
xattention_block_size,
278-
xattention_stride,
279-
sinks,
280-
adaptive_rkv_start_size,
281-
adaptive_rkv_evictable_sizes,
282-
adaptive_rkv_diversity_block_set_indices,
283-
adaptive_rkv_diversity_block_set_indices_begins});
262+
auto pa =
263+
std::make_shared<op::PagedAttentionExtension>(OutputVector{Q,
264+
K,
265+
V,
266+
key_cache_0,
267+
value_cache_0,
268+
past_lens,
269+
subsequence_begins,
270+
block_indices,
271+
block_indices_begins,
272+
scale,
273+
sliding_window,
274+
alibi_slopes,
275+
max_context_len,
276+
score_aggregation_window,
277+
rotated_block_indices,
278+
rotation_deltas,
279+
rotation_trig_lut,
280+
xattention_threshold,
281+
xattention_block_size,
282+
xattention_stride,
283+
sinks,
284+
adaptive_rkv_start_size,
285+
adaptive_rkv_evictable_sizes,
286+
adaptive_rkv_diversity_block_set_indices,
287+
adaptive_rkv_diversity_block_set_indices_begins,
288+
token_type_ids});
284289
pa->get_rt_info()["num_k_heads"] = numKeyHeads;
285290
pa->get_rt_info()["k_head_size"] = keyHeadSize;
286291
pa->get_rt_info()["num_v_heads"] = numValueHeads;
@@ -307,7 +312,8 @@ TEST_P(ConvertPagedAttnInputsTest, checkPrecisionAndShape) {
307312
adaptive_rkv_start_size,
308313
adaptive_rkv_evictable_sizes,
309314
adaptive_rkv_diversity_block_set_indices,
310-
adaptive_rkv_diversity_block_set_indices_begins});
315+
adaptive_rkv_diversity_block_set_indices_begins,
316+
token_type_ids});
311317
}
312318
ov::pass::ConvertPagedAttnInputs::KVCacheConfig cacheConfig;
313319
cacheConfig.keyCacheBlockSize = blockSize[0];

0 commit comments

Comments (0)