18 changes: 18 additions & 0 deletions modules/genai_optimizations/README.md
@@ -17,6 +17,24 @@ This module provides experimental optimizations for GenAI models in PyTorch. The
- **Tri-Shape Mode** – A static block-sparse attention pattern that preserves the initial tokens, local windows, and the final segment of the query, forming a triangular structure to capture critical tokens while maintaining instruction-following performance in both turn-0 and multi-request scenarios. Paper: https://arxiv.org/pdf/2412.10319
- **XAttention Mode** – A dynamic block-sparse attention mechanism that accelerates inference by focusing computation on the most important regions of the attention matrix using antidiagonal block scoring, reducing FLOPs and memory usage without significant loss of accuracy. Paper: https://arxiv.org/pdf/2503.16428

- [**KV Cache Token Eviction**](./token_eviction.py):
Designed to optimize KV cache memory usage during autoregressive generation in LLMs. It selectively removes less important cached tokens while preserving those crucial for contextual understanding, enabling efficient long-sequence inference under constrained memory. Note that currently eviction starts only after the full prompt has been processed; i.e., no eviction takes place during the prefill phase.

The KV cache is split into three parts: **start**, **intermediate (evictable)**, and **recent**. The size of each part is configurable (see the configuration sketch after the mode list below):
- **Start Area** – Initial tokens that are never evicted.
- **Intermediate Area** – Tokens that can be evicted based on importance scores.
- **Recent Area** – Most recent tokens that are preserved (not evicted while in this area, but naturally migrate toward the evictable area as text generation continues).

Eviction granularity can be **per-token** or **per-group**:
- **Per-token** – Tokens are evicted independently from the KV cache.
- **Per-group** – Only fully filled groups from the evictable area are removed. Tokens are managed in consecutive, non-overlapping groups, following the concept of *Paged Attention*, which organizes the KV cache into pages: each token belongs to a single page and stays there for the entire generation process. To maximize eviction efficiency, entire pages are evicted rather than individual tokens. The page size is set via the configurable `group_size` parameter.

Supported modes:
- **H2O Mode** – Evicts tokens using the *Heavy-Hitter Oracle* strategy, which accumulates attention scores to identify and retain high-impact tokens. It also preserves recent tokens due to their strong correlation with the current context. Scores are accumulated throughout the entire generation process, and their weighting can be adjusted via the `normalize_scores` parameter, which controls whether attention scores are normalized by the number of times each token was attended to.
Paper: https://arxiv.org/pdf/2306.14048
- **SnapKV Mode** – Modifies the *H2O* approach by computing token importance within a small sliding window of the most recent queries during the prefill stage, then reverting to the H2O strategy during decoding. The authors observed that only a small subset of prompt tokens is sufficient for accurate response generation.
Paper: https://arxiv.org/pdf/2404.14469
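
A minimal configuration sketch is shown below. It uses the `KVCacheCompressionMode`, `KVCacheCompressionParameters`, and `KVCacheCompressor` classes exported by `genai_opt`; wrapping generation in the compressor as a context manager mirrors how the benchmark scripts apply it and is an assumption here, not the only way to use it.

```python
from genai_opt import KVCacheCompressionMode, KVCacheCompressionParameters, KVCacheCompressor

# Sketch: configure H2O eviction with per-group (paged) granularity
params = KVCacheCompressionParameters(
    algorithm=KVCacheCompressionMode("h2o"),  # or "snapkv"
    granularity="per_group",                  # "per_token" or "per_group"
    group_size=32,                            # page size used for per-group eviction
    start_tokens=32,                          # initial tokens that are never evicted
    intermediate_tokens=1024,                 # budget for the evictable area
    recent_tokens=128,                        # most recent tokens that are preserved
    normalize_scores=True,                    # normalize scores by attention counts
    window_size=None,                         # SnapKV score-aggregation window (unused for H2O)
)
compressor = KVCacheCompressor(eviction_parameters=params)

# Applied as a context manager around generation (as in the benchmark scripts)
with compressor:
    output_ids = model.generate(**inputs, max_new_tokens=256)
```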

## Supported and tested models

Large Language Models:
22 changes: 19 additions & 3 deletions modules/genai_optimizations/benchmarks/README.md
@@ -17,12 +17,18 @@ python longbench.py \
--subset samsum \
--model meta-llama/Llama-3.2-1B-Instruct \
--use_custom_attention \
--prefill_impl tri-shape
--prefill_impl tri-shape \
--enable_eviction \
--algorithm h2o \
--granularity per_group \
--normalize_scores \
--intermediate_tokens 1024
```
This will automatically:

- Download the selected model and dataset
- Apply sparse attention computation during the prefill stage
- Apply token eviction during the decoding stage
- Evaluate the model and report the score

</details>
@@ -46,13 +52,18 @@ python mmebench.py \
--num_keep_tokens 128 \
--theta 0.5 \
--use_custom_attention \
--prefill_impl x-attention
--prefill_impl x-attention \
--enable_eviction \
--algorithm snapkv \
--granularity per_group \
--window_size 8
```
This will automatically:

- Download the selected model and dataset
- Apply the visual token pruning algorithm
- Apply sparse attention computation during the prefill stage
- Apply token eviction during the decoding stage
- Evaluate the model and report the score

</details>
@@ -73,14 +84,19 @@ python milebench.py \
--num_keep_tokens 64 \
--theta 0.5 \
--use_custom_attention \
--prefill_impl tri-shape
--prefill_impl tri-shape \
--enable_eviction \
--algorithm snapkv \
--granularity per_group \
--window_size 8
```

This will automatically:

- Download the selected model and dataset
- Apply the visual token pruning algorithm
- Apply sparse attention computation during the prefill stage
- Apply token eviction during the decoding stage
- Evaluate the model and report the score

</details>
26 changes: 14 additions & 12 deletions modules/genai_optimizations/benchmarks/longbench.py
@@ -19,8 +19,8 @@
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

from genai_opt import SparseAttention
from utils import add_attention_args
from utils import add_attention_args, add_token_eviction_args
from utils import get_eviction_patcher, get_sparse_attention_patcher

# (Phi3 and DeepSeek issue)
# AttributeError: 'DynamicCache' object has no attribute 'get_max_length'. Did you mean: 'get_seq_length'?
@@ -315,28 +315,29 @@ def evaluate(args):
args.model, trust_remote_code=True, token=os.environ.get("HF_TOKEN", None)
)

kwargs = {"temperature": None, "top_p": None, "top_k": None}
# force attn_implementation="eager" when using token eviction without custom attention
if args.enable_eviction and not args.use_custom_attention:
kwargs["attn_implementation"] = "eager"

model = AutoModelForCausalLM.from_pretrained(
args.model,
# attn_implementation="eager",
trust_remote_code=True,
dtype=torch.float16,
device_map="auto",
token=os.environ.get("HF_TOKEN", None),
temperature=None,
top_p=None,
top_k=None,
**kwargs,
).eval()

patchers = []
if args.use_custom_attention:
sparse_attn = SparseAttention(
algorithm=args.prefill_impl,
threshold=args.threshold,
recent_size=args.recent_size,
last_query_size=args.last_query_size,
)
sparse_attn = get_sparse_attention_patcher(args)
patchers.append(sparse_attn)

if args.enable_eviction:
token_eviction = get_eviction_patcher(args)
patchers.append(token_eviction)

max_new_tokens = dataset.get_max_new_tokens()
answers = []
max_length = 4500
@@ -391,6 +392,7 @@ def evaluate(args):
parser.add_argument("--model", type=str, required=True, help="Model name")

add_attention_args(parser)
add_token_eviction_args(parser)
args = parser.parse_args()

evaluate(args)
30 changes: 17 additions & 13 deletions modules/genai_optimizations/benchmarks/milebench.py
@@ -20,8 +20,9 @@
from transformers import AutoProcessor

from logging import getLogger
from genai_opt import get_inputs_embeds, SparseAttention
from utils import add_attention_args, add_visual_pruning_args
from genai_opt import get_inputs_embeds
from utils import add_attention_args, add_visual_pruning_args, add_token_eviction_args
from utils import get_eviction_patcher, get_sparse_attention_patcher


logger = getLogger(__name__)
@@ -454,21 +455,25 @@ def get_model_class(model_name):

add_visual_pruning_args(parser)
add_attention_args(parser)
add_token_eviction_args(parser)
args = parser.parse_args()

dataset = MileBenchDataset(data_dir=args.data_dir, subset=args.subset)
processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
model_cls = get_model_class(args.model)

kwargs = {"temperature": None, "top_p": None, "top_k": None}
# force attn_implementation="eager" when using token eviction without custom attention
if args.enable_eviction and not args.use_custom_attention:
kwargs["attn_implementation"] = "eager"

model = model_cls.from_pretrained(
args.model,
# attn_implementation="eager",
trust_remote_code=True,
dtype=torch.bfloat16,
device_map="auto",
token=os.environ.get("HF_TOKEN", None),
temperature=None,
top_p=None,
top_k=None,
**kwargs
)
model = model.eval()

@@ -482,12 +487,11 @@ def get_model_class(model_name):

contexts = []
if args.use_custom_attention:
sparse_prefill = SparseAttention(
algorithm=args.prefill_impl,
threshold=args.threshold,
recent_size=args.recent_size,
last_query_size=args.last_query_size,
)
contexts.append(sparse_prefill)
sparse_attn = get_sparse_attention_patcher(args)
contexts.append(sparse_attn)

if args.enable_eviction:
token_eviction = get_eviction_patcher(args)
contexts.append(token_eviction)

evaluate(dataset, processor, model, num_keep_tokens=num_keep_tokens, theta=theta, contexts=contexts)
29 changes: 16 additions & 13 deletions modules/genai_optimizations/benchmarks/mmebench.py
@@ -17,8 +17,9 @@
from transformers import AutoProcessor
from transformers import set_seed

from genai_opt import get_inputs_embeds, SparseAttention
from utils import add_attention_args, add_visual_pruning_args
from genai_opt import get_inputs_embeds
from utils import add_attention_args, add_visual_pruning_args, add_token_eviction_args
from utils import get_eviction_patcher, get_sparse_attention_patcher


class MetricCalculator:
@@ -89,16 +90,19 @@ def evaluate(args):

processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
model_cls = get_model_class(model_name)

kwargs = {"temperature": None, "top_p": None, "top_k": None}
# force attn_implementation="eager" when using token eviction without custom attention
if args.enable_eviction and not args.use_custom_attention:
kwargs["attn_implementation"] = "eager"

model = model_cls.from_pretrained(
model_name,
trust_remote_code=True,
# attn_implementation="eager",
dtype=torch.bfloat16,
device_map="auto",
token=os.environ.get("HF_TOKEN", None),
temperature=None,
top_p=None,
top_k=None,
**kwargs
).eval()

if args.enable_visual_pruning:
@@ -111,15 +115,13 @@ def evaluate(args):

contexts = []
if args.use_custom_attention:
print(f"Enable custom attention kernel with {args.prefill_impl} implementation")
sparse_prefill = SparseAttention(
algorithm=args.prefill_impl,
threshold=args.threshold,
recent_size=args.recent_size,
last_query_size=args.last_query_size,
)
sparse_prefill = get_sparse_attention_patcher(args)
contexts.append(sparse_prefill)

if args.enable_eviction:
token_eviction = get_eviction_patcher(args)
contexts.append(token_eviction)

all_items = []
with ExitStack() as stack:
for ctx in contexts:
Expand Down Expand Up @@ -230,6 +232,7 @@ def get_model_class(model_name):

add_visual_pruning_args(parser)
add_attention_args(parser)
add_token_eviction_args(parser)
args = parser.parse_args()

evaluate(args)
53 changes: 52 additions & 1 deletion modules/genai_optimizations/benchmarks/utils.py
@@ -1,6 +1,7 @@
# Copyright (C) 2018-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from genai_opt import SparseAttention
from genai_opt import KVCacheCompressionMode, KVCacheCompressionParameters, KVCacheCompressor

def add_visual_pruning_args(parser):
group = parser.add_argument_group("Visual Token Pruning Arguments")
@@ -28,3 +29,53 @@ def add_attention_args(parser):
help="Window size of recent tokens each query can attend to in the Tri-shape pattern"
)
return parser


def add_token_eviction_args(parser):
group = parser.add_argument_group("Token Eviction Arguments")
group.add_argument("--enable_eviction", action="store_true", help="Enable token eviction")
group.add_argument("--algorithm", default="snapkv", choices=["snapkv", "h2o"], help="The KV cache eviction algorithm")
group.add_argument("--granularity", default="per_group", choices=["per_token", "per_group"], help="Eviction granularity")
group.add_argument(
"--normalize_scores",
action="store_true",
help="Whether to normalize the attention scores by the number of times each token was attended to."
)
group.add_argument(
"--start_tokens",
type=int,
default=32,
help="The number of tokens in the beginning of the cache (least recent) to be retained"
)
group.add_argument("--intermediate_tokens", type=int, default=1024, help="The number of intermediate tokens to consider for eviction")
group.add_argument("--recent_tokens", type=int, default=128, help="The number of most recent tokens to be retained")
group.add_argument("--group_size", type=int, default=32, help="Group size for per-group eviction strategy")
group.add_argument("--window_size", type=int, default=None, help="The size of the importance score aggregation window")
return parser


def get_sparse_attention_patcher(args):
print(f"Enable custom attention kernel with {args.prefill_impl} implementation")
return SparseAttention(
algorithm=args.prefill_impl,
threshold=args.threshold,
recent_size=args.recent_size,
last_query_size=args.last_query_size,
output_attentions=args.enable_eviction, # output attention weights only if eviction is enabled
)


def get_eviction_patcher(args):
print(f"Enable token eviction with {args.algorithm} algorithm")
algorithm = KVCacheCompressionMode(args.algorithm)
params = KVCacheCompressionParameters(
algorithm=algorithm,
granularity=args.granularity,
group_size=args.group_size,
start_tokens=args.start_tokens,
recent_tokens=args.recent_tokens,
intermediate_tokens=args.intermediate_tokens,
normalize_scores=args.normalize_scores,
window_size=args.window_size,
)
return KVCacheCompressor(eviction_parameters=params)
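
For reference, a minimal usage sketch of these helpers, mirroring how the benchmark scripts enter both patchers through an `ExitStack` (here `args`, `model`, `inputs`, and `max_new_tokens` are assumed to be defined as in those scripts):

```python
from contextlib import ExitStack

contexts = []
if args.use_custom_attention:
    contexts.append(get_sparse_attention_patcher(args))  # sparse prefill attention
if args.enable_eviction:
    contexts.append(get_eviction_patcher(args))          # KV cache token eviction

with ExitStack() as stack:
    for ctx in contexts:
        stack.enter_context(ctx)  # both patchers act as context managers
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
```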
1 change: 1 addition & 0 deletions modules/genai_optimizations/genai_opt/__init__.py
@@ -3,3 +3,4 @@

from genai_opt.visual_token_pruning import get_inputs_embeds
from genai_opt.sparse_attention import SparseAttention
from genai_opt.token_eviction import KVCacheCompressionMode, KVCacheCompressionParameters, KVCacheCompressor