69 changes: 68 additions & 1 deletion QEfficient/base/modeling_qeff.py
@@ -23,16 +23,20 @@
OnnxTransformPipeline,
)
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.blocking.blocking_configurator import build_transformer_blocking_config_for_transform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.transformers.models.pytorch_transforms import BlockingAttentionTransform
from QEfficient.utils import (
constants,
create_json,
create_model_params,
dump_qconfig,
generate_mdp_partition_config,
get_attr_or_key,
hash_dict_params,
load_json,
require_value,
)
from QEfficient.utils.export_utils import export_wrapper

@@ -287,7 +291,6 @@ def _export(
else:
input_names.append(param)

# import ipdb; ipdb.set_trace()
try:
torch.onnx.export(
self.model,
@@ -345,13 +348,15 @@ def get_onnx_path(
retain_full_kv: Optional[bool] = False,
enable_mla: Optional[bool] = False,
mla_absorption_config: Optional[bool] = False,
mdp_ts_num_devices: Optional[int] = 1,
):
kwargs = {
"offload_pt_weights": offload_pt_weights,
"use_onnx_subfunctions": use_onnx_subfunctions,
"retain_full_kv": retain_full_kv,
"enable_mla": enable_mla,
"mla_absorption_config": mla_absorption_config,
"mdp_ts_num_devices": mdp_ts_num_devices,
}

if prefill_only:
@@ -366,6 +371,45 @@
self.export(**kwargs)
return self.onnx_path

def transform(
self,
ctx_len: Optional[int] = None,
seq_len: Optional[int] = None,
bs: Optional[int] = 1,
num_devices: int = 1,
qaic_config: Optional[dict] = None,
**compiler_options,
):
# Apply transformations that depend on compilation parameters.

qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None)

model_config = getattr(self.model, "config", None) or getattr(self.model.model, "config", None)

if model_config:
blocking_config = build_transformer_blocking_config_for_transform(
model_config,
ctx_len=ctx_len,
seq_len=seq_len,
bs=bs,
num_devices=num_devices,
qaic_config=qaic_config,
**compiler_options,
)
else:
# Without a model config, blocking cannot be applied to this model.
blocking_config = None

if blocking_config is not None:
self.model, _ = BlockingAttentionTransform.apply(self.model, attn_blocking_config=blocking_config)
blocking_kwargs = self.hash_params.setdefault("blocking_kwargs", {})
if blocking_config.num_kv_blocks:
blocking_kwargs["num_kv_blocks"] = blocking_config.num_kv_blocks
if blocking_config.num_q_blocks:
blocking_kwargs["num_q_blocks"] = blocking_config.num_q_blocks
if blocking_config.head_block_size:
blocking_kwargs["head_block_size"] = blocking_config.head_block_size

@dump_qconfig
def _compile(
self,
@@ -384,6 +428,10 @@ def _compile(
offload_pt_weights: Optional[bool] = True,
enable_chunking: Optional[bool] = False,
retain_full_kv: Optional[bool] = None,
disable_blocking: Optional[bool] = True,
blocking_mode: Optional[str] = "hqkv",
vtcm_ratio: Optional[float] = 0.75,
qaic_config: Optional[dict] = None,
enable_mla: Optional[bool] = False,
mla_absorption_config: Optional[Dict[str, bool]] = False,
**compiler_options,
@@ -411,6 +459,24 @@

For the QNN compilation path, when enable_qnn is set to True, any parameters passed in compiler_options will be ignored.
"""

# Transform before export
qaic_config = qaic_config if qaic_config else getattr(self.model, "qaic_config", None)
bs = require_value(get_attr_or_key(specializations[0], ("batch_size", "batch")), "batch size")
seq_len = get_attr_or_key(specializations[0], ("cl", "seq_len", "sequence_length"))
ctx_len = get_attr_or_key(specializations[0], ("ctx_len", "context_length"))
self.transform(
ctx_len=ctx_len,
seq_len=seq_len,
bs=bs,
num_devices=mdp_ts_num_devices,
disable_blocking=disable_blocking,
blocking_mode=blocking_mode,
vtcm_ratio=vtcm_ratio,
qaic_config=qaic_config,
**compiler_options,
)

onnx_path = Path(
onnx_path
if onnx_path
@@ -425,6 +491,7 @@
retain_full_kv,
enable_mla,
mla_absorption_config,
mdp_ts_num_devices,
)
)
compile_dir = Path(compile_dir or onnx_path.parent)
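
With these changes, `_compile` reads the blocking shapes (batch size, seq_len, ctx_len) from the first specialization and calls `transform` before resolving the ONNX path; `mdp_ts_num_devices` is now forwarded to `get_onnx_path`, and the chosen blocking sizes are recorded in `hash_params`. A minimal sketch of how this flow might be driven — the model wrapper object and the specialization dict are illustrative assumptions, while the keyword arguments come from the new `_compile` signature:

# Hypothetical driver: assumes a QEfficient model wrapper that inherits _compile()
# from QEFFBaseModel; the specialization keys mirror those probed via get_attr_or_key.
specializations = [{"batch_size": 1, "seq_len": 128, "ctx_len": 4096}]
qpc_path = qeff_model._compile(
    specializations=specializations,
    mdp_ts_num_devices=4,      # now also part of the ONNX path lookup
    disable_blocking=False,    # opt in to attention blocking
    blocking_mode="hqkv",      # forwarded to the blocking configurator
    vtcm_ratio=0.75,
)
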
126 changes: 126 additions & 0 deletions QEfficient/blocking/attention_blocking.py
@@ -0,0 +1,126 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Callable, Dict, Optional

import torch
from transformers.cache_utils import Cache

from QEfficient.blocking.blocked_attention_forwards import (
blocked_h_attention_forward,
blocked_hqkv_attention_forward,
blocked_kv_attention_forward,
blocked_q_attention_forward,
blocked_qkv_attention_forward,
invalid_blocking_attention_forward,
)


class BlockingMode(str, Enum):
NONE = ""
KV = "kv"
Q = "q"
H = "h"
QKV = "qkv"
HQKV = "hqkv"


@dataclass
class AttentionBlockingConfig:
mode: BlockingMode = BlockingMode.NONE
num_kv_blocks: Optional[int] = None
num_q_blocks: Optional[int] = None
head_block_size: Optional[int] = None


def supports_blocked_kv(past_key_value: Optional[Cache]) -> bool:
return past_key_value is not None and hasattr(past_key_value, "read_only_blockedKV")


_STRATEGIES: Dict[BlockingMode, Callable] = {
BlockingMode.NONE: invalid_blocking_attention_forward,
BlockingMode.KV: blocked_kv_attention_forward,
BlockingMode.Q: blocked_q_attention_forward,
BlockingMode.H: blocked_h_attention_forward,
BlockingMode.QKV: blocked_qkv_attention_forward,
BlockingMode.HQKV: blocked_hqkv_attention_forward,
}


def get_blocking_strategy(config: AttentionBlockingConfig) -> Callable:
return _STRATEGIES.get(config.mode, _STRATEGIES[BlockingMode.NONE])


def generic_blocked_attention_interface(
module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
layer_idx: int,
past_key_value: Cache,
blocking_config: AttentionBlockingConfig,
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_seen_tokens: Optional[int] = None,
non_blocked_forward: Optional[Callable] = None,
**kwargs,
):
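"""
Dispatch attention through a blocked strategy when ``blocking_config`` selects one,
otherwise fall back to ``non_blocked_forward``.

When the mode blocks KV and the cache exposes ``read_only_blockedKV``, new key/value
states are written with ``write_only`` instead of a full ``update``; all other paths
update the cache before computing attention.
"""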
use_kv_blocked = (
blocking_config is not None and "kv" in blocking_config.mode and supports_blocked_kv(past_key_value)
)
use_blocking = blocking_config is not None and (blocking_config.mode != BlockingMode.NONE)

if past_key_value is not None:
if use_kv_blocked:
cache_kwargs = {
"batch_index": batch_index,
"position_ids": position_ids,
"past_seen_tokens": past_seen_tokens,
}
past_key_value.write_only(key, value, module.layer_idx, cache_kwargs)
else:
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
if comp_ctx_lengths is not None:
attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
cache_kwargs["CCL"] = attention_mask.shape[-1]
key, value = past_key_value.update(key, value, module.layer_idx, cache_kwargs)

if use_blocking:
strategy = get_blocking_strategy(blocking_config)
attn_output, attn_weights = strategy(
module=module,
query=query,
key=key,
value=value,
attention_mask=attention_mask,
scaling=scaling,
cache_kwargs=cache_kwargs,
layer_idx=layer_idx,
past_key_value=past_key_value,
num_kv_blocks=blocking_config.num_kv_blocks,
num_q_blocks=blocking_config.num_q_blocks,
head_block_size=blocking_config.head_block_size,
)
else:
attn_output, attn_weights = non_blocked_forward(
module,
query,
key,
value,
attention_mask,
scaling=scaling,
**kwargs,
)

return attn_output, attn_weights
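
The strategy table makes the blocking mode declarative: AttentionBlockingConfig names the mode, get_blocking_strategy resolves the matching forward, and the invalid-blocking handler is the fallback. A minimal sketch of the lookup, using only names defined in this file:

from QEfficient.blocking.attention_blocking import (
    AttentionBlockingConfig,
    BlockingMode,
    get_blocking_strategy,
)

# Block queries and keys/values into fixed-size chunks.
config = AttentionBlockingConfig(mode=BlockingMode.QKV, num_kv_blocks=4, num_q_blocks=2)
strategy = get_blocking_strategy(config)
print(strategy.__name__)  # blocked_qkv_attention_forward

# The default (empty) mode resolves to the invalid-blocking handler.
fallback = get_blocking_strategy(AttentionBlockingConfig())
print(fallback.__name__)  # invalid_blocking_attention_forward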