Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 17 additions & 24 deletions optimum/exporters/executorch/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,17 +261,13 @@ def _prepare_decoder_only_export_inputs(self, max_seq_len: int):
return example_inputs_embeds, example_cache_position, dynamic_shapes

def _register_custom_attention(self, exportable_module: torch.nn.Module):
_custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
if self.use_custom_sdpa:
if self.use_custom_kv_cache:
_custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_passthrough)
# Manually set the attention implementation to custom_sdpa_ring_kv_cache
# This handles both regular sdpa and one for sliding window/local attention
exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
else:
# Manually set the attention implementation to custom_sdpa_ring_kv_cache
# This handles both regular sdpa and one for sliding window/local attention
exportable_module.model.model.config._attn_implementation = "custom_sdpa"

def export(
Expand Down Expand Up @@ -327,26 +323,23 @@ def export(
dynamic_shapes=dynamic_shapes,
strict=True,
)
# Apply RemoveTransposes pass to remove
# any back-to-back transpose ops that are not needed
# e.g. output of update_cache is transposed and
# input to custom_sdpa is transposed.
from executorch.extension.llm.export.export_passes import (
RemoveRedundantTransposes,
)
# Remove back-to-back transpose ops introduced by custom SDPA + custom KV cache.
if self.use_custom_sdpa:
from executorch.extension.llm.export.export_passes import (
RemoveRedundantTransposes,
)

mutated_gm = RemoveRedundantTransposes()(exported_program.module())[0]
exported_program = torch.export.export(
mutated_gm,
args=(),
# For the ET runner, it's important to have cache position as the 2nd arg.
kwargs={
"inputs_embeds": inputs_embeds,
"cache_position": cache_position,
},
dynamic_shapes=dynamic_shapes,
strict=True,
)
mutated_gm = RemoveRedundantTransposes()(exported_program.module())[0]
exported_program = torch.export.export(
mutated_gm,
args=(),
kwargs={
"inputs_embeds": inputs_embeds,
"cache_position": cache_position,
},
dynamic_shapes=dynamic_shapes,
strict=True,
)
exported_programs["text_decoder"] = exported_program

# 2. Export token embeddings
Expand Down
73 changes: 73 additions & 0 deletions optimum/exporters/executorch/recipes/metal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.

import logging
import math
from typing import Dict, Union

import torch
from packaging.version import parse

from executorch import version as executorch_version
Expand All @@ -23,13 +25,53 @@
EXECUTORCH_VERSION = parse(executorch_version.__version__)
METAL_BACKEND_AVAILABLE = EXECUTORCH_VERSION >= parse("1.1.0.dev20251017")

METAL_SUPPORTED_HEAD_DIMS = (64, 96, 128)

if METAL_BACKEND_AVAILABLE:
try:
from executorch.backends.apple.metal.metal_backend import MetalBackend
from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner
except ImportError:
METAL_BACKEND_AVAILABLE = False


def _sdpa_decomposition(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False):
"""Decompose scaled_dot_product_attention into matmul + softmax.

The Metal SDPA kernel only supports head_dim in {64, 96, 128}.
For models with other head dimensions (e.g. Gemma3 with head_dim=256),
we decompose into ops that AOTI can compile for Metal: matmul, softmax, etc.
"""
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_weight = torch.ops.aten.matmul.default(
query, torch.ops.aten.transpose.int(key, -2, -1)
)
attn_weight = torch.ops.aten.mul.Scalar(attn_weight, scale_factor)
if is_causal:
causal_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril()
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
attn_bias = torch.ops.aten.masked_fill.Scalar(attn_bias, ~causal_mask, float("-inf"))
attn_weight = torch.ops.aten.add.Tensor(attn_weight, attn_bias)
if attn_mask is not None:
attn_weight = torch.ops.aten.add.Tensor(attn_weight, attn_mask)
attn_weight = torch.ops.aten.softmax.int(attn_weight, -1)
return torch.ops.aten.matmul.default(attn_weight, value)


def _linear_bias_decomposition(input, weight, bias=None):
"""Decompose linear with bias into matmul + add.

Avoids Metal backend issues with reinterpret_tensor_wrapper when
linear layers have biases (0-stride problem).
"""
weight_t = torch.ops.aten.t.default(weight)
out = torch.ops.aten.matmul.default(input, weight_t)
if bias is not None:
return torch.ops.aten.add.Tensor(out, bias)
return out


if METAL_BACKEND_AVAILABLE:
from tabulate import tabulate
from torch.export import ExportedProgram
Expand Down Expand Up @@ -124,6 +166,37 @@ def _lower_to_executorch(
):
raise NotImplementedError("Custom SDPA implementation is not supported for Metal.")

# Metal uses standard SDPA, not custom SDPA with custom KV cache.
if hasattr(model, "use_custom_sdpa"):
model.use_custom_sdpa = False
if hasattr(model, "use_custom_kv_cache"):
model.use_custom_kv_cache = False

exported_progs = model.export()

# Decompose ops that the Metal backend cannot handle natively.
decomp_table = {
torch.ops.aten.linear.default: _linear_bias_decomposition,
}

# The Metal SDPA kernel only supports head_dim in {64, 96, 128}.
# For models with unsupported head_dim, decompose SDPA into matmul + softmax.
head_dim = getattr(model.config, "head_dim", None)
if head_dim is None:
text_config = getattr(model.config, "text_config", None)
if text_config is not None:
head_dim = getattr(text_config, "head_dim", None)

if head_dim is not None and head_dim not in METAL_SUPPORTED_HEAD_DIMS:
logging.info(
f"Model head_dim={head_dim} is not natively supported by Metal SDPA kernel "
f"(supported: {METAL_SUPPORTED_HEAD_DIMS}). Decomposing SDPA into matmul + softmax."
)
decomp_table[torch.ops.aten.scaled_dot_product_attention.default] = _sdpa_decomposition

exported_progs = {
key: ep.run_decompositions(decomp_table)
for key, ep in exported_progs.items()
}

return _lower_to_executorch(exported_progs, model.metadata)
Loading