[PagedAttention] Add bidirectional attention mask within image groups#34111

Merged
mlukasze merged 31 commits intoopenvinotoolkit:masterfrom
p-wysocki:attn_idea_2
Mar 12, 2026

Conversation

@p-wysocki
Contributor

@p-wysocki p-wysocki commented Feb 13, 2026

Details:

  • Enable bidirectional attention within image groups for Gemma3: a token may attend to future tokens when they belong to the same image group. In the figure below, the left mask shows full attention and the right mask shows sliding window attention, for an image described using 5 image tokens.
(figure: attention mask comparison)
  • The implementation is based on how transformers implements this functionality, using the tokenizer output token_type_ids, which classifies tokens as image vs text:

https://github.com/huggingface/transformers/blob/a6ef2a6f3549dd3267e8f5bafe4976a3217784bb/src/transformers/models/gemma3/modular_gemma3.py#L783-L796

  • This PR allows PagedAttention to take the graph input token_type_ids and compute the bidirectional image attention, modifying the causal mask behavior. The behavior is unchanged when token_type_ids is not given.
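To make the masking rule concrete, here is a standalone sketch (not the actual OpenVINO kernel; `build_mask` and `group_end` are hypothetical names) of how a causal mask can be relaxed so that tokens inside the same contiguous run of image tokens (token_type_ids == 1) attend to each other bidirectionally, mirroring the transformers logic linked above:

```cpp
#include <cstddef>
#include <vector>

// Build a boolean attention mask: mask[q][k] == true means query token q may
// attend to key token k. Text tokens stay strictly causal; image tokens may
// additionally attend forward up to the end of their own image group.
std::vector<std::vector<bool>> build_mask(const std::vector<int>& token_type_ids) {
    const size_t n = token_type_ids.size();
    // group_end[i] = one-past-the-end index of the image group containing i
    // (for text tokens this is i + 1, which preserves plain causal behavior).
    std::vector<size_t> group_end(n);
    for (size_t i = n; i-- > 0;) {
        if (token_type_ids[i] == 1 && i + 1 < n && token_type_ids[i + 1] == 1)
            group_end[i] = group_end[i + 1];
        else
            group_end[i] = i + 1;
    }
    std::vector<std::vector<bool>> mask(n, std::vector<bool>(n, false));
    for (size_t q = 0; q < n; ++q)
        for (size_t k = 0; k < n; ++k)
            mask[q][k] = (k <= q) || (token_type_ids[q] == 1 && k < group_end[q]);
    return mask;
}
```

For `token_type_ids = {0, 1, 1, 1, 0}`, the three image tokens form one group, so token 1 can attend to token 3 (a future token in its group) but not to token 4 (text outside the group).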

Tests:

Questions/TODO:

  • Do other tokenizers implement token_type_ids differently? The Gemma3 docs don't seem exactly aligned with what the tokenizer implements.
    • tokenizer output: 0 if a token is a text token, 1 if it is an image token
    • documentation: it says token_type_ids distinguishes sentences, which I suppose is different from image vs text.
  • Should we keep the mask generation inside PagedAttention? SDPA has the attention mask as an optional input, which produces a causal mask by default. Perhaps we should construct attention masks outside of PagedAttention and feed them to the operator? This is especially important given that we're also adding a tree mask.

Tickets:

  • 171180

Signed-off-by: p-wysocki <przemyslaw.wysocki@intel.com>
@p-wysocki p-wysocki requested review from a team as code owners February 13, 2026 11:49
@github-actions github-actions bot added category: Core OpenVINO Core (aka ngraph) category: CPU OpenVINO CPU plugin category: transformations OpenVINO Runtime library - Transformations labels Feb 13, 2026
@p-wysocki p-wysocki requested review from a team as code owners February 25, 2026 11:27
@github-actions github-actions bot added the category: GPU OpenVINO GPU plugin label Feb 25, 2026
@mryzhov mryzhov self-requested a review February 26, 2026 08:31
ov::element::i32,
getInputShapeAtPort(PagedAttentionExecutor::ID_ADAPTIVE_RKV_DIVERSITY_BLOCK_SET_INDICES_BEGINS)));

// token_type_ids, i32, [B_token | 0] or [1, B_token]
Contributor

As a common interface of PA, could you encode the ids input and decode it during execution?

For example, using [3, 9] to replace [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, ...]
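One plausible reading of the reviewer's [3, 9] example is a list of inclusive (start, end) index pairs, one per contiguous run of image tokens. A hypothetical encode/decode pair (not part of the PR) could look like this:

```cpp
#include <utility>
#include <vector>

// Compress a per-token 0/1 token_type_ids vector into inclusive
// (start, end) index pairs, one per contiguous run of image tokens.
std::vector<std::pair<int, int>> encode_runs(const std::vector<int>& ids) {
    std::vector<std::pair<int, int>> runs;
    const int n = static_cast<int>(ids.size());
    for (int i = 0; i < n; ++i) {
        if (ids[i] == 1) {
            const int start = i;
            while (i + 1 < n && ids[i + 1] == 1)
                ++i;
            runs.emplace_back(start, i);  // inclusive [start, end]
        }
    }
    return runs;
}

// Expand the run list back into the per-token vector at execution time.
std::vector<int> decode_runs(const std::vector<std::pair<int, int>>& runs, int total) {
    std::vector<int> ids(total, 0);
    for (const auto& [start, end] : runs)
        for (int i = start; i <= end; ++i)
            ids[i] = 1;
    return ids;
}
```

The trade-off discussed below is that the dense vector already exists as a model input, so the compressed form would add an encode/decode step without removing any graph plumbing.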

Contributor Author

@rkazants @Wovchena I think we should make a final decision on this change, since it would be a significant refactor of the PR.

Contributor

It has been discussed that the token_type_ids input is already in the model, so connecting it directly to the PA (as proposed in the PR) simplifies the whole logic.

Contributor

@mitruska mitruska left a comment

Great work on enabling token types functionality.
Let's minimize impact on existing PA logic and ensure that supported models are not affected.

Comment on lines +499 to +502
size_t get_ncausal(size_t q_global_idx, size_t default_ncausal, size_t cur_kv_len) const {
    if (!_token_type || q_global_idx >= _image_group_end.size()) {
        return default_ncausal;
    }
Contributor

This initial check is performed on every call of get_ncausal; with the proposed common updates it will run for every index calculation even if the token_type_ids input is not provided. Consider optimizing it so the existing logic is not affected.

Perf validation is needed to ensure supported models are not affected by these PA changes.

Contributor Author

I made the job easier by providing a false check which should be quickly evaluated by CPU branch prediction, but one way or another, performance testing needs to be done. Is there a standard procedure for testing PA performance? cc @maxnick

Contributor

As there was no perf regression detected (shared in the ticket), I don't consider it a blocker.

But the same check is still repeated at each loop step, while the state does not change: it depends only on the provided inputs and remains the same for a given compiled model. An approach with a function pointer could be investigated as one possible option, making the decision once for the whole kernel logic rather than at each element.
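A minimal sketch of the function-pointer idea (hypothetical names, not the PR's code): select the ncausal policy once per kernel invocation, so the hot loop performs an indirect call instead of re-testing an invariant condition per element.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical policy holder: the branch on "is token_type_ids provided?"
// is resolved once in select(), outside the per-token loop.
struct NCausalPolicy {
    // Non-null only when the token_type_ids input was provided.
    const std::vector<size_t>* image_group_end = nullptr;

    using Fn = size_t (*)(const NCausalPolicy&, size_t q_idx, size_t default_ncausal);

    // Plain causal path: no extra work per element.
    static size_t causal(const NCausalPolicy&, size_t /*q_idx*/, size_t default_ncausal) {
        return default_ncausal;
    }

    // Image-aware path: extend the causal horizon to the image group end.
    static size_t image_aware(const NCausalPolicy& p, size_t q_idx, size_t default_ncausal) {
        if (q_idx >= p.image_group_end->size())
            return default_ncausal;
        return std::max(default_ncausal, (*p.image_group_end)[q_idx]);
    }

    // Decide once per kernel invocation, not per element.
    Fn select() const {
        return image_group_end ? &NCausalPolicy::image_aware : &NCausalPolicy::causal;
    }
};
```

The same effect could also be reached with a template parameter on the kernel (compile-time dispatch), at the cost of an extra instantiation.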

Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 17 out of 17 changed files in this pull request and generated 5 comments.

}

memory::ptr get_token_type_ids_memory() {
    std::vector<int> token_type_ids = { 0 };

Copilot AI Mar 2, 2026


token_type_ids is created as a 1-element tensor ({0}), but B_token in this test is often >1 (e.g., subsequence {10,0}), so this input is shape-inconsistent with the per-token semantics. To keep existing tests behavior unchanged, consider passing an empty tensor (shape {0}) to represent “not provided”, or otherwise generate token_type_ids with length equal to total query tokens (sum of subsequence_desc.num_tokens) and make its layout dynamic like other per-token inputs.

Suggested change
std::vector<int> token_type_ids = { 0 };
// Represent token_type_ids as "not provided" by using an empty tensor (shape {0})
std::vector<int> token_type_ids;

auto adaptive_rkv_evictable_sizes_layout = layout{ov::PartialShape{1}, data_types::f32, format::bfyx};
auto adaptive_rkv_diversity_block_set_indices_layout = layout{ov::PartialShape{1}, data_types::f32, format::bfyx};
auto adaptive_rkv_diversity_block_set_indices_begins_layout = layout{ov::PartialShape{1}, data_types::f32, format::bfyx};
auto token_type_ids_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx};

Copilot AI Mar 2, 2026


token_type_ids_layout is hard-coded as PartialShape{1}. If this input is intended to be optional (empty shape) or to match the per-token dimension, using a fixed {1} shape can mask shape issues and differs from other dynamic per-token inputs. Consider using {0} for “not provided” or {-1} (and/or {1, -1}) if the test intends to model real token_type_ids.

Suggested change
auto token_type_ids_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx};
auto token_type_ids_layout = layout{ov::PartialShape{-1}, data_types::i32, format::bfyx};

Comment on lines +24 to 26
validate_inputs_count(op, {26});
auto inputs = p.GetInputInfo(op);
auto prim = cldnn::paged_attention(layer_type_name_ID(op), inputs);

Copilot AI Mar 2, 2026


This change updates the op to require 26 inputs, but the Intel GPU implementation does not appear to consume token_type_ids anywhere (no references to TOKEN_TYPE_IDS / token_type_ids under src/plugins/intel_gpu/src). If the PR’s goal is to enable bidirectional attention on GPU, the kernels/impl need to incorporate this new input; otherwise GPU behavior will remain unchanged despite the new required input.

get_input_size() == 25,
"PagedAttensionExtension expects 25 inputs, but it has ",
get_input_size() == 26,
"PagedAttensionExtension expects 26 inputs, but it has ",

Copilot AI Mar 2, 2026


Typo in the validation error message: PagedAttensionExtension → PagedAttentionExtension. Since this string was touched in this change, it’s a good opportunity to fix it for clearer diagnostics.

Suggested change
"PagedAttensionExtension expects 26 inputs, but it has ",
"PagedAttentionExtension expects 26 inputs, but it has ",

Comment on lines +219 to +221
ov::Tensor token_type_tensor(ov::element::i32, {seq_len});
std::memcpy(token_type_tensor.data<int32_t>(), token_types.data(), seq_len * sizeof(int32_t));


Copilot AI Mar 2, 2026


std::memcpy is used here but the file doesn’t include <cstring>. Please add the missing standard header to avoid relying on transitive includes (which can break with different toolchains / build flags).

@p-wysocki p-wysocki added this to the 2026.1 milestone Mar 6, 2026
Contributor

@mitruska mitruska left a comment

Conditional approval, as the change is confirmed to resolve the reported issue on the target model without detected regressions.

Refactoring the PA and its growing number of inputs is a topic to be continued, but it shouldn't block all feature development.

@p-wysocki Please contribute to the PagedAttention specification PR, describing the token_type_ids input as suggested:

Comment on lines +472 to +482
                  const PlainTensor& subsequence_begins,
                  const PlainTensor& past_lens) {
    _has_image_tokens = true;
    _token_type = token_type;
    auto total_tokens = static_cast<int32_t>(token_type.m_dims[0]);
    _image_group_end.resize(total_tokens);

    auto seq_count = static_cast<int32_t>(past_lens.m_dims[0]);
    for (int32_t seq = 0; seq < seq_count; seq++) {
        auto seq_begin = subsequence_begins.ptr<int32_t>()[seq];
        auto seq_end = subsequence_begins.ptr<int32_t>()[seq + 1];
Contributor

Is the subsequence_begins size ensured at this point?

It is accessed at [seq] and [seq + 1] here, so the final loop iteration reads subsequence_begins[seq_count], where seq_count = past_lens.m_dims[0]. This means subsequence_begins.size() >= (seq_count + 1) must be ensured.

Contributor Author

The shape is validated here:

subsequence_begins.assert_dims({B_seq + 1});

Comment on lines +499 to +502
size_t get_ncausal(size_t q_global_idx, size_t default_ncausal, size_t cur_kv_len) const {
if (!_token_type || q_global_idx >= _image_group_end.size()) {
return default_ncausal;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As there were no perf regression detected (shared in the ticket), I don't put it as a blocker.

But still the same check is repeated for each loop step, while the state is not changing and depends on the provided inputs - remaining the same for the same compiled model.
Approach with function ptr could be investigated as one of possible options here to make the decision once per whole kernel logic, not at each element.

for (size_t m = q_start; m < q_end; m++) {
    // apply attention mask & softmax
    auto ncausal = (cur_kv_len - q_cnt + (m - q_start) + 1);
    auto ncausal = get_ncausal(q_token_start + m, cur_kv_len - q_cnt + (m - q_start) + 1, cur_kv_len);
Contributor

Related to my previous comment.

As there was no perf regression detected at this point (shared in the ticket), I don't consider it a blocker.

But the first argument of get_ncausal is used only when the new token_type_ids input is provided, while q_token_start + m is calculated at each loop step even if not used.

Contributor Author

Let's discuss offline

const auto adaptive_rkv_diversity_block_set_indices_begins =
std::make_shared<op::v0::Parameter>(element::i32, PartialShape{5});

const auto token_type_ids = std::make_shared<op::v0::Parameter>(ov::element::i32, ov::Shape{0});
Contributor

Only the existing type prop test cases have been updated, presenting a dummy token_type_ids input with empty shape. It would be good to also add tests with examples of valid inputs and shapes.

Contributor Author

Added new tests in #34661

@mlukasze mlukasze added this pull request to the merge queue Mar 12, 2026
Merged via the queue into openvinotoolkit:master with commit c694fbc Mar 12, 2026
217 checks passed

Labels

category: Core OpenVINO Core (aka ngraph) category: CPU OpenVINO CPU plugin category: GPU OpenVINO GPU plugin category: transformations OpenVINO Runtime library - Transformations Code Freeze


Development

Successfully merging this pull request may close these issues.

[Bug]: gemma-3-12b-it-int8-ov generates gibberish when extracting text from an image
