Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
92fbee3
feat: Add KV blocking attention optimization for Causal LMs
vbaddi Jan 18, 2026
25450eb
nit: fix reducesum for kv blocking w/sub and hash params to handle qa…
vbaddi Jan 18, 2026
9bb7c4d
nit: rename BlockedKVAttentionTransform to KVBlockingAttentionTransform
vbaddi Jan 18, 2026
56cfcf4
nit: add kv+q blocking generalize support for qwen3/qwen3_moe/mllama
vbaddi Jan 18, 2026
406706f
Initial changes for moving transforms out of pretrained
kdulla Feb 10, 2026
dbb82c0
fixed CB issue with skip futures and added additional support for dif…
kdulla Feb 12, 2026
64796d3
added additional support for gpt-oss and qwen2.5vl
kdulla Feb 13, 2026
b5b93b7
Added support for num kv blocks that do not exactly divide CL
kdulla Feb 24, 2026
1538d48
Fixed blocking configurator
kdulla Feb 25, 2026
c345b68
Fixing formatting
kdulla Feb 25, 2026
6c142ab
Further formatting corrections
kdulla Feb 26, 2026
ab6069f
Test formatting error
kdulla Feb 26, 2026
d911e25
restructured some file paths and added optional qaic_config override …
kdulla Mar 5, 2026
17ce9f6
minor formatting fix
kdulla Mar 5, 2026
0823a5b
Restructured blocking changes to have generic blocked attention inter…
kdulla Mar 9, 2026
ccedb77
Removed print statement
kdulla Mar 10, 2026
3fba5b4
added in progress qwen2.5 vl support and mistral support
kdulla Mar 11, 2026
a665a20
fixed minor formatting issues
kdulla Mar 17, 2026
66cba12
fixing generic blocked attention based on code reviews
kdulla Mar 31, 2026
ca7dae9
removed incomplete blocking code from untested models, added generic …
kdulla Apr 2, 2026
63a635a
fixed formatting and a few incorrect merges
kdulla Apr 2, 2026
d14dd52
additional merge errors
kdulla Apr 2, 2026
e409baa
added batch blocking
kdulla Apr 6, 2026
c9fa160
fixed past key value merge error
kdulla Apr 6, 2026
bcd263a
fixed blocking test config
kdulla Apr 6, 2026
6ff0bbe
added a fix for batch blocking not exporting correctly
kdulla Apr 7, 2026
795806e
fixed non-dynamic kv blocking
kdulla Apr 8, 2026
cdd9823
Added Qwen3VL blocking support
kdulla Apr 8, 2026
98ecfd1
Fixes to qwen3 automatic blocking
kdulla Apr 9, 2026
aa328f9
added missing init file
kdulla Apr 9, 2026
0f673cb
fix for qwen2.5vl
kdulla Apr 10, 2026
96275b3
fixed mixtral modeling and removed unnecessary transform if we alread…
kdulla Apr 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,22 @@
SplitTensorsTransform,
)
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.blocking.blocking_configurator import build_transformer_blocking_config_for_transform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.transformers.models.pytorch_transforms import (
BlockingAttentionTransform,
)
from QEfficient.utils import (
constants,
create_json,
create_model_params,
dump_qconfig,
generate_mdp_partition_config,
get_attr_or_key,
hash_dict_params,
load_json,
require_value,
)
from QEfficient.utils.export_utils import export_wrapper

Expand Down Expand Up @@ -328,6 +334,8 @@ def get_onnx_path(
offload_pt_weights: Optional[bool] = True,
use_onnx_subfunctions: Optional[bool] = False,
retain_full_kv: Optional[bool] = False,
qaic_config: Optional[dict] = None,
**compiler_options,
):
kwargs = {
"offload_pt_weights": offload_pt_weights,
Expand All @@ -344,9 +352,63 @@ def get_onnx_path(
}
)

# Transform before export
qaic_config = (
qaic_config if qaic_config else getattr(self.model, "qaic_config", None) if hasattr(self, "model") else None
)
if specializations is not None:
bs = require_value(get_attr_or_key(specializations[0], ("batch_size", "batch")), "batch size")
seq_len = get_attr_or_key(specializations[0], ("cl", "seq_len", "sequence_length"))
ctx_len = get_attr_or_key(specializations[0], ("ctx_len", "context_length"))
else:
bs = None
seq_len = None
ctx_len = None

self.transform(
ctx_len=ctx_len,
seq_len=seq_len,
bs=bs,
qaic_config=qaic_config,
**compiler_options,
)

self.export(**kwargs)
return self.onnx_path

def transform(
    self,
    ctx_len: Optional[int] = None,
    seq_len: Optional[int] = None,
    bs: Optional[int] = 1,
    num_devices: int = 1,
    qaic_config: Optional[dict] = None,
    **compiler_options,
):
    """Apply transformations that depend on compilation parameters.

    Builds an attention-blocking configuration from the model's HF config and
    the compile-time shapes (``bs``/``seq_len``/``ctx_len``), then rewrites the
    model with ``BlockingAttentionTransform`` when a blocking config is produced.

    Args:
        ctx_len: Context length from the compile specialization, if known.
        seq_len: Sequence length from the compile specialization, if known.
        bs: Batch size from the compile specialization.
        num_devices: Number of target devices (mdp partitioning).
        qaic_config: Optional override; falls back to ``self.model.qaic_config``.
        **compiler_options: Forwarded to the blocking configurator.
    """
    qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None)

    # Resolve the HF config: prefer the wrapper's own `.config`, otherwise look
    # one level deeper. The inner lookup also uses getattr with a default so a
    # model without a nested `.model` attribute cannot raise AttributeError
    # (the original evaluated `self.model.model.config` unconditionally).
    model_config = getattr(self.model, "config", None) or getattr(getattr(self.model, "model", None), "config", None)

    if model_config:
        blocking_config = build_transformer_blocking_config_for_transform(
            model_config,
            ctx_len=ctx_len,
            seq_len=seq_len,
            bs=bs,
            num_devices=num_devices,
            qaic_config=qaic_config,
            **compiler_options,
        )
    else:
        # Without a model config, this is not a model that is possible to block.
        blocking_config = None

    if blocking_config is not None:
        self.model, _ = BlockingAttentionTransform.apply(self.model, attn_blocking_config=blocking_config)
        # Record the blocking parameters so they participate in the export hash.
        self.hash_params["blocking_kwargs"] = blocking_config

@dump_qconfig
def _compile(
self,
Expand All @@ -365,6 +427,7 @@ def _compile(
offload_pt_weights: Optional[bool] = True,
enable_chunking: Optional[bool] = False,
retain_full_kv: Optional[bool] = None,
qaic_config: Optional[dict] = None,
**compiler_options,
) -> str:
"""
Expand Down Expand Up @@ -402,6 +465,9 @@ def _compile(
offload_pt_weights,
use_onnx_subfunctions,
retain_full_kv,
num_devices=mdp_ts_num_devices,
qaic_config=qaic_config,
**compiler_options,
)
)
compile_dir = Path(compile_dir or onnx_path.parent)
Expand Down
6 changes: 6 additions & 0 deletions QEfficient/blocking/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
166 changes: 166 additions & 0 deletions QEfficient/blocking/attention_blocking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Callable, Dict, Optional

import torch
from transformers.cache_utils import Cache

from QEfficient.blocking.blocked_attention_forwards import (
blocked_bhqkv_attention_forward,
blocked_h_attention_forward,
blocked_hqkv_attention_forward,
blocked_kv_attention_forward,
blocked_q_attention_forward,
blocked_qkv_attention_forward,
)


class BlockingMode(str, Enum):
    """Which axes of the attention computation are tiled into blocks.

    The ``str`` mixin lets the value double as a searchable tag: callers test
    ``"kv" in mode`` to detect any mode that blocks the key/value axis
    (KV, QKV, HQKV, and BHQKV all contain ``"kv"``).
    """

    NONE = ""  # no blocking applied
    KV = "kv"  # block along the key/value (context) axis
    Q = "q"  # block along the query axis
    H = "h"  # block along the attention-head axis
    QKV = "qkv"  # query + key/value blocking
    HQKV = "hqkv"  # heads + query + key/value blocking
    BHQKV = "bhqkv"  # batch + heads + query + key/value blocking


@dataclass
class AttentionBlockingConfig:
    """Parameters selecting and sizing an attention-blocking strategy."""

    # Which axes are blocked; BlockingMode.NONE disables blocking.
    mode: BlockingMode = BlockingMode.NONE
    # Number of tiles along the key/value (context) axis, when KV blocking is on.
    num_kv_blocks: Optional[int] = None
    # Number of tiles along the query axis, when Q blocking is on.
    num_q_blocks: Optional[int] = None
    # Heads per tile, when head blocking is on.
    head_block_size: Optional[int] = None
    # NOTE(review): semantics not visible in this file — presumably skips some
    # KV-cache work; confirm against the blocking configurator / call sites.
    skip_kv: Optional[bool] = False
    # Number of tiles along the batch axis (BHQKV mode).
    num_batch_blocks: Optional[int] = None


def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool:
    """Return True when *past_key_value* is a cache exposing the blocked-KV
    read API (i.e. it has a ``read_only_blockedKV`` attribute)."""
    if past_key_value is None:
        return False
    return hasattr(past_key_value, "read_only_blockedKV")


# Dispatch table mapping each blocking mode to its blocked-attention forward
# implementation. BlockingMode.NONE is deliberately absent, so lookups for it
# yield None (via .get) — callers must handle the unblocked case themselves.
_STRATEGIES: Dict[BlockingMode, Callable] = {
    BlockingMode.KV: blocked_kv_attention_forward,
    BlockingMode.Q: blocked_q_attention_forward,
    BlockingMode.H: blocked_h_attention_forward,
    BlockingMode.QKV: blocked_qkv_attention_forward,
    BlockingMode.HQKV: blocked_hqkv_attention_forward,
    BlockingMode.BHQKV: blocked_bhqkv_attention_forward,
}


def get_blocking_strategy(config: AttentionBlockingConfig) -> Callable:
    """Look up the blocked-attention forward registered for ``config.mode``.

    Returns None when no strategy is registered for the mode (e.g.
    ``BlockingMode.NONE``).
    """
    try:
        return _STRATEGIES[config.mode]
    except KeyError:
        return None


# helper function needed both in generic blocked approach and in other modeling files for non-blocked approach
def past_key_value_update(
    module,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    past_key_value: Cache,
    comp_ctx_lengths: Optional[torch.LongTensor] = None,
    batch_index: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    sliding_window: Optional[int] = None,
):
    """Update the KV cache for ``module.layer_idx`` and return the result.

    Args:
        module: Attention module; only its ``layer_idx`` attribute is read.
        key / value: New key/value states to merge into the cache.
        attention_mask: Causal mask; only its trailing dim is consulted when
            ``comp_ctx_lengths`` is given (assumes it is not None in that case
            — TODO confirm at call sites).
        past_key_value: Cache to update; may be None, in which case the inputs
            are returned unchanged.
        comp_ctx_lengths: Compute-context lengths; when set, records the
            clipped mask length under the ``"CCL"`` cache kwarg.
        batch_index / position_ids: Forwarded to ``Cache.update``.
        sliding_window: When set, marks the update as sliding-window and
            forwards the cache's ``sliding_window_len``.

    Returns:
        ``(key, value, cache_kwargs)``. When ``past_key_value`` is None the
        inputs pass through untouched and ``cache_kwargs`` is ``{}`` (the
        original raised UnboundLocalError on this path).
    """
    # Bind unconditionally so the return statement can never see an unbound
    # local when there is no cache to update.
    cache_kwargs = {}
    if past_key_value is not None:
        cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
        if sliding_window is not None:
            cache_kwargs.update(
                {
                    # Guaranteed True inside this branch (original recomputed
                    # `sliding_window is not None` redundantly).
                    "is_sliding": True,
                    "sliding_window": past_key_value.sliding_window_len,
                }
            )
        if comp_ctx_lengths is not None:
            # Clip the mask to the compute context length and record that
            # length for the cache update.
            attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
            cache_kwargs["CCL"] = attention_mask.shape[-1]
        key, value = past_key_value.update(key, value, module.layer_idx, cache_kwargs)
    return key, value, cache_kwargs


def generic_blocked_attention_interface(
    module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    layer_idx: int,
    past_key_value: Cache,
    blocking_config: AttentionBlockingConfig,
    comp_ctx_lengths: Optional[torch.LongTensor] = None,
    batch_index: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_seen_tokens: Optional[int] = None,
    non_blocked_forward: Callable = None,
    score_mod: Optional[Callable] = None,
    position_bias: Optional[torch.Tensor] = None,
    sinks: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    **kwargs,
):
    """Run attention through the blocked forward selected by ``blocking_config.mode``.

    Updates the KV cache (write-only when KV blocking is active so the blocked
    forward can read K/V back block by block; full read-write otherwise), then
    dispatches to the registered strategy.

    Returns:
        ``(attn_output, attn_weights)`` as produced by the selected strategy.

    Raises:
        ValueError: if no strategy is registered for ``blocking_config.mode``
            (the original raised an opaque ``TypeError`` by calling None).
    """
    # KV-blocked path requires a cache exposing the blocked read API; the
    # substring test matches KV, QKV, HQKV and BHQKV modes.
    use_kv_blocked = (
        blocking_config is not None and "kv" in blocking_config.mode and supports_blocked_kv(past_key_value)
    )

    # Bind unconditionally so the strategy call below can never see an unbound
    # local when past_key_value is None (a latent UnboundLocalError before).
    cache_kwargs = {}

    if past_key_value is not None:
        if use_kv_blocked and sliding_window is None:
            # Write-only cache update: the blocked forward re-reads K/V
            # per block, so no merged key/value is materialized here.
            # (The original also carried an `if sliding_window is not None`
            # branch here, unreachable inside this `sliding_window is None`
            # arm; it has been removed.)
            cache_kwargs = {
                "batch_index": batch_index,
                "position_ids": position_ids,
                "past_seen_tokens": past_seen_tokens,
            }
            past_key_value.write_only(key, value, module.layer_idx, cache_kwargs)
        else:
            # Standard read-write update; returns the merged key/value states.
            key, value, cache_kwargs = past_key_value_update(
                module=module,
                key=key,
                value=value,
                attention_mask=attention_mask,
                past_key_value=past_key_value,
                comp_ctx_lengths=comp_ctx_lengths,
                batch_index=batch_index,
                position_ids=position_ids,
                sliding_window=sliding_window,
            )

    strategy = _STRATEGIES.get(blocking_config.mode)
    if strategy is None:
        raise ValueError(f"No blocked attention strategy registered for mode {blocking_config.mode!r}")

    attn_output, attn_weights = strategy(
        module=module,
        query=query,
        key=key,
        value=value,
        attention_mask=attention_mask,
        scaling=scaling,
        cache_kwargs=cache_kwargs,
        layer_idx=layer_idx,
        past_key_value=past_key_value,
        num_kv_blocks=blocking_config.num_kv_blocks,
        num_q_blocks=blocking_config.num_q_blocks,
        head_block_size=blocking_config.head_block_size,
        num_batch_blocks=blocking_config.num_batch_blocks,
        score_mod=score_mod,
        position_bias=position_bias,
        sinks=sinks,
    )

    return attn_output, attn_weights
Loading
Loading