7 changes: 3 additions & 4 deletions .github/workflows/test_ipex.yml
@@ -21,10 +21,9 @@ jobs:
strategy:
fail-fast: false
matrix:
transformers-version: ["4.51.0", "4.52.*"]
torch-version: ["2.7.0"]
test-file:
[test_modeling.py, test_pipelines.py, test_modeling_causal_lm.py]
transformers-version: ["4.55.4"]
torch-version: ["2.8.0"]
test-file: [test_modeling.py, test_pipelines.py, test_modeling_causal_lm.py]

runs-on: ubuntu-22.04

186 changes: 145 additions & 41 deletions optimum/exporters/ipex/cache_utils.py
@@ -1,12 +1,110 @@
import os
from typing import List, Optional, Tuple
from typing import Any, Optional, Tuple

import intel_extension_for_pytorch as ipex
import torch
from intel_extension_for_pytorch.llm.modules import PagedAttention
from transformers import Cache, PretrainedConfig

from optimum.intel.utils.import_utils import is_ipex_version

try:
from transformers.cache_utils import CacheLayerMixin
except ImportError:
# Fallback for older transformers versions
class CacheLayerMixin:
"""Minimal fallback for CacheLayerMixin when not available in transformers."""

pass


class IPEXCacheLayer(CacheLayerMixin):
"""
A cache layer for IPEX PagedAttention that stores key and value states
as paged tensors optimized for Intel XPU and CPU devices.
"""

is_compileable = True
is_sliding = False

def __init__(
self,
key_cache_shape: tuple = None,
value_cache_shape: tuple = None,
device: torch.device = None,
dtype: torch.dtype = torch.float32,
supports_flash_decoding: bool = False,
**kwargs,
):
super().__init__()
# Create cache tensors if shapes are provided, otherwise use pre-created tensors
if key_cache_shape is not None and value_cache_shape is not None:
self.keys = torch.zeros(key_cache_shape, dtype=dtype, device=device)
self.values = torch.zeros(value_cache_shape, dtype=dtype, device=device)
else:
# Fallback for pre-created tensors
self.keys = kwargs.get("key_cache", None)
self.values = kwargs.get("value_cache", None)

self.device = device
self._supports_flash_decoding = supports_flash_decoding

# Mark tensors as static for torch.compile
if self.keys is not None:
torch._dynamo.mark_static_address(self.keys)
if self.values is not None:
torch._dynamo.mark_static_address(self.values)

# Set max_batch_size and max_cache_len properties required by parent
# Extract from tensor shape if available
if self.keys is not None:
self.max_batch_size = kwargs.get("max_batch_size", 1)
self.max_cache_len = kwargs.get("max_cache_len", self.keys.shape[0])
else:
self.max_batch_size = kwargs.get("max_batch_size", 1)
self.max_cache_len = kwargs.get("max_cache_len", 1)

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Update method will be called by the paged cache's reshape_and_cache method."""
# For IPEXPagedCache, the actual update is handled by reshape_and_cache
# This method just returns the current cache tensors
return self.keys, self.values

def get_seq_length(self, cache_position=None) -> int:
"""Returns the sequence length. For paged cache, this is managed by the parent."""
# This will be overridden by the parent IPEXPagedCache
return 0

def get_max_cache_shape(self) -> int:
"""Returns the maximum cache shape."""
return self.max_cache_len

def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
"""Return the length and offset of the cache for attention mask generation."""
# For paged attention, this is handled differently
kv_offset = 0
kv_length = self.get_max_cache_shape()
return kv_length, kv_offset

def reset(self) -> None:
"""Reset cache values while preserving objects."""
if self.keys is not None:
self.keys.zero_()
if self.values is not None:
self.values.zero_()

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorder cache for beam search."""
if self.keys is not None and self.keys.numel():
device = self.keys.device
self.keys = self.keys.index_select(0, beam_idx.to(device))
if self.values is not None and self.values.numel():
device = self.values.device
self.values = self.values.index_select(0, beam_idx.to(device))


class IPEXPagedCache(Cache):
@@ -43,28 +141,12 @@ def __init__(
dtype=None,
**kwargs,
) -> None:
super().__init__()
self.max_batch_size = max_batch_size
# Prepare parameters before calling super().__init__
default_device = torch.device("xpu") if ipex._C._has_xpu() else torch.device("cpu")
device = device or default_device
self.device = device
self._supports_flash_decoding = (
is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99")
)
# Used in `generate` to keep tally of how many tokens the cache has seen

self._seen_tokens = torch.zeros([max_batch_size], dtype=torch.int32, device=device)
self.slots = torch.zeros([max_cache_len * max_batch_size], dtype=torch.int32, device=device)
torch._dynamo.mark_static_address(self._seen_tokens)
torch._dynamo.mark_static_address(self.slots)
default_block_size = 16 if max_cache_len <= 64 else 64
self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))
self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size
self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
max_batch_size, -1
)
self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=device)
self.max_cache_len = max_cache_len
# Set up cache-related attributes
self.num_kv_heads = config.num_key_value_heads
self.num_hidden_layers = config.num_hidden_layers
if getattr(config, "head_dim", None) is not None:
@@ -73,26 +155,49 @@ def __init__(
head_size = config.hidden_size // config.num_attention_heads
self.head_size = head_size

self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
# Calculate block parameters
default_block_size = 16 if max_cache_len <= 64 else 64
self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))
self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size

# Determine cache tensor shapes
if device.type == "cpu":
key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
elif device.type == "xpu":
if self._supports_flash_decoding:
key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
else:
key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
for i in range(config.num_hidden_layers):
new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)

# Store custom parameters for later use
self._custom_layer_kwargs = {
"key_cache_shape": key_cache_shape,
"value_cache_shape": value_cache_shape,
"device": device,
"dtype": dtype,
"max_batch_size": max_batch_size,
"max_cache_len": max_cache_len,
}

# Now call parent without our custom parameters
super().__init__(config=config, layer_classes=IPEXCacheLayer, batch_size=max_batch_size, **kwargs)

# Update layer_init_kwargs after parent initialization
self.layer_init_kwargs.update(self._custom_layer_kwargs)

# Clear existing layers and recreate with correct parameters
self.layers.clear()
self.append_new_layers(self.num_hidden_layers - 1)

# Initialize other attributes
self._seen_tokens = torch.zeros([max_batch_size], dtype=torch.int32, device=device)
self.slots = torch.zeros([max_cache_len * max_batch_size], dtype=torch.int32, device=device)
torch._dynamo.mark_static_address(self._seen_tokens)
torch._dynamo.mark_static_address(self.slots)

self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
max_batch_size, -1
)
self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=device)

def reshape_and_cache(
self,
@@ -103,8 +208,7 @@ def reshape_and_cache(
slots: torch.Tensor,
):
# TODO: unify API definition between CPU and XPU in IPEX version > 2.6
if self.device.type == "xpu" and self._supports_flash_decoding:
# make a WA here as slots here is padded but XPU does not support slots with length not equal to key length, will fix it in IPEX 2.8
if self.device.type == "xpu":
valid_len = key.shape[0]
truncated_slots = slots[:valid_len]
PagedAttention.reshape_and_cache_flash(
@@ -183,10 +287,10 @@ def update(
"""

self.reshape_and_cache(
key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
key_states, value_states, self.layers[layer_idx].keys, self.layers[layer_idx].values, self.slots
)

return self.key_cache[layer_idx], self.value_cache[layer_idx]
return self.layers[layer_idx].keys, self.layers[layer_idx].values

def get_seq_length(self) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""
@@ -215,10 +319,10 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
updated_table[i] = self.block_tables[i][nb - 1]
for layer_idx in range(self.num_hidden_layers):
# The updated_table cannot contain the whole block table, otherwise will cause core-dump.
self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx].index_select(
self.layers[layer_idx].keys[updated_table] = self.layers[layer_idx].keys.index_select(
0, updated_table[beam_idx]
)
self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx].index_select(
self.layers[layer_idx].values[updated_table] = self.layers[layer_idx].values.index_select(
0, updated_table[beam_idx]
)

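For readers following the cache refactor above, here is a minimal usage sketch of how the new layer-based cache could be constructed directly. Only the keyword names and attributes visible in this diff (`config`, `max_batch_size`, `max_cache_len`, `device`, `dtype`, `layers`, `block_size`, `num_blocks`) are assumed; the checkpoint id, sizes, and dtype are illustrative only, not part of the PR.

```python
# Hedged usage sketch of IPEXPagedCache after this refactor -- not code from the PR.
# Requires intel_extension_for_pytorch and a transformers version in the supported range.
import torch
from transformers import AutoConfig

from optimum.exporters.ipex.cache_utils import IPEXPagedCache

config = AutoConfig.from_pretrained("Qwen/Qwen2-0.5B-Instruct")  # hypothetical checkpoint

cache = IPEXPagedCache(
    config=config,
    max_batch_size=2,
    max_cache_len=256,            # > 64, so the default block_size is 64
    device=torch.device("cpu"),   # optional; defaults to xpu when available, else cpu
    dtype=torch.bfloat16,
)

# With max_cache_len=256, block_size=64 and max_batch_size=2:
# num_blocks = (256 // 64 + (256 % 64 != 0)) * 2 = 8
print(cache.block_size, cache.num_blocks)        # 64 8

# Each decoder layer now owns its paged key/value tensors through an IPEXCacheLayer:
print(len(cache.layers), cache.layers[0].keys.shape)
```

Note that the key/value cache shapes differ between the CPU and XPU backends, as set up in the constructor shown above.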
4 changes: 2 additions & 2 deletions optimum/exporters/ipex/model_patcher.py
@@ -51,8 +51,8 @@


# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
_TRANSFORMERS_MIN_VERSION = "4.51.0"
_TRANSFORMERS_MAX_VERSION = "4.52.99"
_TRANSFORMERS_MIN_VERSION = "4.55.0"
_TRANSFORMERS_MAX_VERSION = "4.55.99"

_IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",)

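Side note on these pins: below is a hedged sketch of the kind of import-time guard such MIN/MAX constants typically back. The actual check in `model_patcher.py` is not shown in this diff, so treat this as illustrative only.

```python
# Illustrative guard using the pinned range above -- not necessarily the PR's actual check.
import transformers
from packaging import version

_TRANSFORMERS_MIN_VERSION = "4.55.0"
_TRANSFORMERS_MAX_VERSION = "4.55.99"

_current = version.parse(transformers.__version__)
if not (version.parse(_TRANSFORMERS_MIN_VERSION) <= _current <= version.parse(_TRANSFORMERS_MAX_VERSION)):
    raise ImportError(
        f"IPEX export requires transformers>={_TRANSFORMERS_MIN_VERSION},<={_TRANSFORMERS_MAX_VERSION}, "
        f"but {transformers.__version__} is installed."
    )
```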
21 changes: 6 additions & 15 deletions optimum/exporters/ipex/modeling_utils.py
@@ -396,13 +396,6 @@ def _falcon_model_forward(
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)

@@ -662,10 +655,6 @@ def _qwen2_model_forward(
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
use_cache = False

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

@@ -704,8 +693,9 @@ def _qwen2_model_forward(
position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))

if past_key_values is None:
attention_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
past_key_values_length = 0
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, inputs_embeds.shape[:2], inputs_embeds, past_key_values_length
)

# decoder layers
@@ -817,8 +807,9 @@ def _mistral_model_forward(
sin = sin.reshape(-1, sin.shape[-1])
position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
if past_key_values is None:
attention_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
past_key_values_length = 0
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask, inputs_embeds.shape[:2], inputs_embeds, past_key_values_length
)

# decoder layers
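For context on the mask change in the qwen2 and mistral forwards above, here is a standalone sketch of the SDPA helper they now call on the first (cache-free) pass. Shapes and values are illustrative, and the helper is assumed to be importable from `transformers.modeling_attn_mask_utils` as in current transformers releases.

```python
# Hedged sketch: what _prepare_4d_causal_attention_mask_for_sdpa produces for a prefill pass.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa

batch_size, seq_len, hidden = 2, 5, 32
inputs_embeds = torch.randn(batch_size, seq_len, hidden)
# 2D padding mask: the second sequence is left-padded by two tokens
attention_mask = torch.tensor([[1, 1, 1, 1, 1], [0, 0, 1, 1, 1]])

mask_4d = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask,
    (batch_size, seq_len),  # input_shape == inputs_embeds.shape[:2]
    inputs_embeds,
    0,                      # past_key_values_length is 0 when past_key_values is None
)
# Returns an additive float mask of shape (batch, 1, seq_len, seq_len + past),
# or None when SDPA can rely on is_causal alone (e.g. an all-ones prefill mask).
print(None if mask_4d is None else mask_4d.shape)  # torch.Size([2, 1, 5, 5])
```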