Skip to content

[Executorch][llm] Enable local global attention in export_llama script #10612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: gh/kimishpatel/189/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from .source_transformation.custom_kv_cache import (
replace_kv_cache_with_custom_kv_cache,
replace_kv_cache_with_quantized_kv_cache,
replace_kv_cache_with_ring_kv_cache,
)

from .source_transformation.quantize import (
Expand Down Expand Up @@ -153,6 +154,23 @@ def build_model(
return export_llama(args)


def parse_list_of_ints(s):
import ast

try:
parsed = ast.literal_eval(s)
if isinstance(parsed, list) and all(isinstance(i, int) for i in parsed):
print(parsed)
return parsed
raise argparse.ArgumentTypeError(
"Must be a list of integers, e.g., [0, 16, 0, 16]"
)
except Exception:
raise argparse.ArgumentTypeError(
"Must be a list of integers, e.g., [0, 16, 0, 16]"
)


def build_args_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output-dir", default=".", help="output directory")
Expand Down Expand Up @@ -363,6 +381,13 @@ def build_args_parser() -> argparse.ArgumentParser:
help="maximum length of context for model to remember",
)

parser.add_argument(
"--local_global_attention",
type=parse_list_of_ints,
default=None,
help="List of integers specifying local and global attention pattern, e.g., [0, 16, 0, 16].",
)

parser.add_argument("-2", "--fairseq2", action="store_true")
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument(
Expand Down Expand Up @@ -1307,6 +1332,14 @@ def _get_source_transforms( # noqa
if args.vulkan:
transforms.append(replace_with_vulkan_rotary_emb)

if args.local_global_attention:
transforms.append(
partial(
replace_kv_cache_with_ring_kv_cache,
layer_sizes=args.local_global_attention,
)
)

return transforms


Expand Down
15 changes: 14 additions & 1 deletion examples/models/llama/source_transformation/custom_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,17 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
# This is needed to ensure that custom ops are registered
from executorch.extension.llm.custom_ops import custom_ops # noqa: F401

assert len(module.layers) > len(
layer_sizes
), f"Length of layer sizes {len(layer_sizes)} must match the number of layers in the module {len(module.layers)}."
multiplier = len(module.layers) // len(layer_sizes)
modulo = len(module.layers) % len(layer_sizes)
assert (
modulo == 0
), f"num layers specified must be multiple of model layers in order to specify pattern. pattern: {layer_sizes} model's num layers {len(module.layers)}"
layer_sizes = layer_sizes * multiplier
logging.info(
"Replacing kv cache with ring kv cache. This modifies the model in place."
f"Applying local sliding window attention with following pattern {layer_sizes}."
)
assert len(layer_sizes) == len(
module.layers
Expand All @@ -534,4 +543,8 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
), f"Transfomer block must have attention module. Transformer block {transformer_block}"
attention = transformer_block.attention
_replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
# if attention's sdpa is custom sdpa then we have to make sure
# it is not doing causal attention
if "SDPACustom" in attention.SDPA.__class__.__name__:
attention.SDPA.use_attention_mask = True
return module
Loading