Skip to content
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1abe8ec
WIP
p-wysocki Feb 9, 2026
34763dd
WIP
p-wysocki Feb 9, 2026
e7f8238
WIP
p-wysocki Feb 9, 2026
daef794
WIP
p-wysocki Feb 12, 2026
fdb3a73
Add tests
p-wysocki Feb 12, 2026
829c430
initial clenaup
p-wysocki Feb 12, 2026
a04d165
Set input as optional
p-wysocki Feb 12, 2026
96188fb
Correct tests
p-wysocki Feb 12, 2026
4d9f607
Remove reshape from graph
p-wysocki Feb 12, 2026
1f6a6d1
Remove debug prints
p-wysocki Feb 12, 2026
2bf68b5
Clenaup
p-wysocki Feb 13, 2026
bcbb855
Merge branch 'master' into attn_idea_2
p-wysocki Feb 13, 2026
ed3374d
Sliding window working
p-wysocki Feb 18, 2026
0fd5001
Move sw to gptoss logic
p-wysocki Feb 18, 2026
5fbbf5e
Working, with debug prints
p-wysocki Feb 18, 2026
81ab320
Cleanup
p-wysocki Feb 18, 2026
4e0d5ac
Merge branch 'attn_idea_2' of https://github.com/p-wysocki/openvino i…
p-wysocki Feb 18, 2026
5f5af24
Cleanup
p-wysocki Feb 18, 2026
362cb80
update copyright
p-wysocki Feb 18, 2026
7841c70
Fix transformation tests, add new one
p-wysocki Feb 18, 2026
9810dc3
Fix convert input tests
p-wysocki Feb 18, 2026
890804b
Fix clang
p-wysocki Feb 18, 2026
6a62dda
Fix smoke tests
p-wysocki Feb 18, 2026
dfc6e1f
Fix smoke test
p-wysocki Feb 18, 2026
a412f6c
Update GPU input count
p-wysocki Feb 25, 2026
810130d
CR
p-wysocki Mar 2, 2026
acbd73e
Add token_type_ids to gemma only
p-wysocki Mar 2, 2026
2da2303
Fix gpu test
p-wysocki Mar 2, 2026
03e9935
Merge branch 'master' into attn_idea_2
p-wysocki Mar 10, 2026
e2347e1
Merge branch 'master' of https://github.com/openvinotoolkit/openvino …
p-wysocki Mar 11, 2026
7176099
Fix conflict issues
p-wysocki Mar 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,33 +52,34 @@ ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& config, Upda
auto adaptive_rkv_evictable_sizes = pattern::any_input(pattern::has_static_rank());
auto adaptive_rkv_diversity_block_set_indices = pattern::any_input(pattern::has_static_rank());
auto adaptive_rkv_diversity_block_set_indices_begins = pattern::any_input(pattern::has_static_rank());

auto result =
pattern::wrap_type<ov::op::PagedAttentionExtension>({Q,
K,
V,
key_cache_0,
value_cache_0,
past_lens,
subsequence_begins,
block_indices,
block_indices_begins,
scale,
sliding_window,
alibi_slopes,
max_context_len,
score_aggregation_window,
rotated_block_indices,
rotation_deltas,
rotation_trig_lut,
xattention_threshold,
xattention_block_size,
xattention_stride,
sinks,
adaptive_rkv_start_size,
adaptive_rkv_evictable_sizes,
adaptive_rkv_diversity_block_set_indices,
adaptive_rkv_diversity_block_set_indices_begins});
auto token_type_ids = pattern::any_input(pattern::has_static_rank());

auto result = pattern::wrap_type<ov::op::PagedAttentionExtension>({Q,
K,
V,
key_cache_0,
value_cache_0,
past_lens,
subsequence_begins,
block_indices,
block_indices_begins,
scale,
sliding_window,
alibi_slopes,
max_context_len,
score_aggregation_window,
rotated_block_indices,
rotation_deltas,
rotation_trig_lut,
xattention_threshold,
xattention_block_size,
xattention_stride,
sinks,
adaptive_rkv_start_size,
adaptive_rkv_evictable_sizes,
adaptive_rkv_diversity_block_set_indices,
adaptive_rkv_diversity_block_set_indices_begins,
token_type_ids});
ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) {
const auto pa_op = m.get_match_root();
auto key_cache = ov::as_type_ptr<v0::Parameter>(pa_op->get_input_node_shared_ptr(3));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "openvino/op/bitwise_and.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/greater.hpp"
Expand Down Expand Up @@ -197,6 +198,15 @@ static std::shared_ptr<ov::Node> handle_baichuan2_13b_alibi(
return res_alibi_slopes;
}

/// Resolves the optional `token_type_ids` input for the PagedAttention op.
///
/// Gemma3-style models expose `token_type_ids` as an optional model-wide
/// parameter; when present it is converted to i32 (the dtype PagedAttention
/// expects for this input). For all other models an empty i32 constant is
/// returned as a placeholder so the op's input count stays fixed.
///
/// @param optional_model_wide_params  map of optional model-wide parameters by name
/// @return the i32-converted parameter, or an empty i32 constant placeholder
static std::shared_ptr<ov::Node> handle_gemma3_token_type_ids(
    const std::map<std::string, std::shared_ptr<v0::Parameter>>& optional_model_wide_params) {
    // Single lookup: reuse the iterator instead of a second .at() on the same key.
    const auto it = optional_model_wide_params.find("token_type_ids");
    if (it != optional_model_wide_params.end()) {
        return std::make_shared<v0::Convert>(it->second, ov::element::i32);
    }
    return v0::Constant::create(ov::element::i32, ov::Shape{0}, {});
}

static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> phi3_sliding_window_pattern() {
auto offset = wrap_type<v0::Constant>();
auto t196 = wrap_type<v1::Add>({any_input(), offset});
Expand All @@ -216,7 +226,7 @@ static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> phi3_sli
return {mask, offset};
}

static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> gpt_oss_sliding_window_pattern() {
static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> gptoss_gemma3_sliding_window_pattern() {
auto q_idx = any_input();
auto kv_idx = any_input();

Expand Down Expand Up @@ -393,9 +403,9 @@ ov::pass::StateManagementPattern::StateManagementPattern(
std::shared_ptr<ov::Node> phi3_mask, phi3_offset;
std::tie(phi3_mask, phi3_offset) = phi3_sliding_window_pattern();

// gpt-oss case
std::shared_ptr<ov::Node> gpt_oss_mask, gpt_oss_offset;
std::tie(gpt_oss_mask, gpt_oss_offset) = gpt_oss_sliding_window_pattern();
// gpt-oss and gemma3 cases
std::shared_ptr<ov::Node> gptoss_gemma3_mask, gptoss_gemma3_offset;
std::tie(gptoss_gemma3_mask, gptoss_gemma3_offset) = gptoss_gemma3_sliding_window_pattern();

// Scale's shape limitations according to SDPA specification
auto scale_predicate = [=](const Output<Node>& output) -> bool {
Expand All @@ -414,7 +424,7 @@ ov::pass::StateManagementPattern::StateManagementPattern(
general_alibi_mask,
jais_alibi_mask,
baichuan2_13b_alibi_mask,
gpt_oss_mask,
gptoss_gemma3_mask,
any_input()});

auto sdpa_with_4_inputs = wrap_type<v13::ScaledDotProductAttention>({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa});
Expand Down Expand Up @@ -602,9 +612,9 @@ ov::pass::StateManagementPattern::StateManagementPattern(
offset = std::make_shared<v0::Convert>(offset, element::i32);
}
sliding_window = std::make_shared<v1::Subtract>(v0::Constant::create(element::i32, Shape{}, {2}), offset);
} else if (pattern_map.count(gpt_oss_offset)) {
auto offset = pattern_map.at(gpt_oss_offset).get_node_shared_ptr();
if (pattern_map.at(gpt_oss_offset).get_partial_shape().rank() != 0) {
} else if (pattern_map.count(gptoss_gemma3_offset)) {
auto offset = pattern_map.at(gptoss_gemma3_offset).get_node_shared_ptr();
if (pattern_map.at(gptoss_gemma3_offset).get_partial_shape().rank() != 0) {
offset = std::make_shared<v15::Squeeze>(offset);
}
if (offset->get_element_type() != element::i32) {
Expand Down Expand Up @@ -737,6 +747,9 @@ ov::pass::StateManagementPattern::StateManagementPattern(
}
OPENVINO_ASSERT(pa_arguments.size() == 25);

pa_arguments.insert(pa_arguments.begin() + 25, handle_gemma3_token_type_ids(optional_model_wide_params));
OPENVINO_ASSERT(pa_arguments.size() == 26);

auto paged_attention = std::make_shared<ov::op::PagedAttentionExtension>(pa_arguments);
paged_attention->get_rt_info()[NUM_K_HEADS] = num_k_heads;
paged_attention->get_rt_info()[K_HEAD_SIZE] = k_head_size;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,33 +121,35 @@ TEST_P(ConvertPagedAttnInputsTest, checkPrecisionAndShape) {
std::make_shared<v0::Parameter>(ov::element::i32, PartialShape{DYN});
auto adaptive_rkv_diversity_block_set_indices_begins =
std::make_shared<v0::Parameter>(ov::element::i32, PartialShape{DYN});
auto token_type_ids = std::make_shared<op::v0::Parameter>(ov::element::i32, ov::Shape{0});

auto pa = std::make_shared<op::PagedAttentionExtension>(
OutputVector{Q,
K,
V,
key_cache_0,
value_cache_0,
past_lens,
subsequence_begins,
block_indices,
block_indices_begins,
scale,
sliding_window,
alibi_slopes,
max_context_len,
score_aggregation_window,
rotated_block_indices,
rotation_deltas,
rotation_trig_lut,
xattention_threshold,
xattention_block_size,
xattention_stride,
sinks,
adaptive_rkv_start_size,
adaptive_rkv_evictable_sizes,
adaptive_rkv_diversity_block_set_indices,
adaptive_rkv_diversity_block_set_indices_begins});
auto pa =
std::make_shared<op::PagedAttentionExtension>(OutputVector{Q,
K,
V,
key_cache_0,
value_cache_0,
past_lens,
subsequence_begins,
block_indices,
block_indices_begins,
scale,
sliding_window,
alibi_slopes,
max_context_len,
score_aggregation_window,
rotated_block_indices,
rotation_deltas,
rotation_trig_lut,
xattention_threshold,
xattention_block_size,
xattention_stride,
sinks,
adaptive_rkv_start_size,
adaptive_rkv_evictable_sizes,
adaptive_rkv_diversity_block_set_indices,
adaptive_rkv_diversity_block_set_indices_begins,
token_type_ids});
pa->get_rt_info()["num_k_heads"] = numKeyHeads;
pa->get_rt_info()["k_head_size"] = keyHeadSize;
pa->get_rt_info()["num_v_heads"] = numValueHeads;
Expand All @@ -174,7 +176,8 @@ TEST_P(ConvertPagedAttnInputsTest, checkPrecisionAndShape) {
adaptive_rkv_start_size,
adaptive_rkv_evictable_sizes,
adaptive_rkv_diversity_block_set_indices,
adaptive_rkv_diversity_block_set_indices_begins});
adaptive_rkv_diversity_block_set_indices_begins,
token_type_ids});

if (isIRKVCacheF16) {
model->set_rt_info("f16", "runtime_options", ov::hint::kv_cache_precision.name());
Expand Down Expand Up @@ -254,33 +257,35 @@ TEST_P(ConvertPagedAttnInputsTest, checkPrecisionAndShape) {
std::make_shared<v0::Parameter>(ov::element::i32, PartialShape{DYN});
auto adaptive_rkv_diversity_block_set_indices_begins =
std::make_shared<v0::Parameter>(ov::element::i32, PartialShape{DYN});
auto token_type_ids = std::make_shared<v0::Parameter>(ov::element::i32, ov::Shape{0});

auto pa = std::make_shared<op::PagedAttentionExtension>(
OutputVector{Q,
K,
V,
key_cache_0,
value_cache_0,
past_lens,
subsequence_begins,
block_indices,
block_indices_begins,
scale,
sliding_window,
alibi_slopes,
max_context_len,
score_aggregation_window,
rotated_block_indices,
rotation_deltas,
rotation_trig_lut,
xattention_threshold,
xattention_block_size,
xattention_stride,
sinks,
adaptive_rkv_start_size,
adaptive_rkv_evictable_sizes,
adaptive_rkv_diversity_block_set_indices,
adaptive_rkv_diversity_block_set_indices_begins});
auto pa =
std::make_shared<op::PagedAttentionExtension>(OutputVector{Q,
K,
V,
key_cache_0,
value_cache_0,
past_lens,
subsequence_begins,
block_indices,
block_indices_begins,
scale,
sliding_window,
alibi_slopes,
max_context_len,
score_aggregation_window,
rotated_block_indices,
rotation_deltas,
rotation_trig_lut,
xattention_threshold,
xattention_block_size,
xattention_stride,
sinks,
adaptive_rkv_start_size,
adaptive_rkv_evictable_sizes,
adaptive_rkv_diversity_block_set_indices,
adaptive_rkv_diversity_block_set_indices_begins,
token_type_ids});
pa->get_rt_info()["num_k_heads"] = numKeyHeads;
pa->get_rt_info()["k_head_size"] = keyHeadSize;
pa->get_rt_info()["num_v_heads"] = numValueHeads;
Expand All @@ -307,7 +312,8 @@ TEST_P(ConvertPagedAttnInputsTest, checkPrecisionAndShape) {
adaptive_rkv_start_size,
adaptive_rkv_evictable_sizes,
adaptive_rkv_diversity_block_set_indices,
adaptive_rkv_diversity_block_set_indices_begins});
adaptive_rkv_diversity_block_set_indices_begins,
token_type_ids});
}
ov::pass::ConvertPagedAttnInputs::KVCacheConfig cacheConfig;
cacheConfig.keyCacheBlockSize = blockSize[0];
Expand Down
Loading
Loading