Use llm_config instead of args in export_llama functions #11162

Open · wants to merge 13 commits into base: gh/jackzhxng/14/base

6 changes: 5 additions & 1 deletion backends/arm/test/models/test_llama.py
@@ -22,6 +22,9 @@
TosaPipelineMI,
)

from executorch.examples.models.llama.config.llm_config_utils import (
convert_args_to_llm_config,
)
from executorch.examples.models.llama.export_llama_lib import (
build_args_parser,
get_llama_model,
@@ -89,8 +92,9 @@ def prepare_model(self):
]
parser = build_args_parser()
args = parser.parse_args(args)
llm_config = convert_args_to_llm_config(args)

llama_model, llama_inputs, llama_meta = get_llama_model(args)
llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)

return llama_model, llama_inputs, llama_meta

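For reference, the caller-side migration this test illustrates is small: parse the existing CLI flags as before, bridge the resulting argparse namespace into an LlmConfig, and hand that to get_llama_model. A minimal sketch, assuming only the names that appear in this diff (the argument list passed to parse_args is illustrative):

from executorch.examples.models.llama.config.llm_config_utils import (
    convert_args_to_llm_config,
)
from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser,
    get_llama_model,
)

parser = build_args_parser()
args = parser.parse_args([])  # illustrative; real callers pass their usual export flags
llm_config = convert_args_to_llm_config(args)  # bridge argparse.Namespace -> LlmConfig
llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)
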
2 changes: 2 additions & 0 deletions examples/models/llama/TARGETS
@@ -67,6 +67,7 @@ runtime.python_library(
"//caffe2:torch",
"//executorch/examples/models:model_base",
"//executorch/examples/models/llama:llama_transformer",
"//executorch/examples/models/llama/config:llm_config",
"//executorch/examples/models:checkpoint",
],
)
@@ -266,6 +267,7 @@ runtime.python_library(
":export_library",
"//executorch/examples/models/llama/config:llm_config",
"fbsource//third-party/pypi/hydra-core:hydra-core",
"fbsource//third-party/pypi/omegaconf:omegaconf",
],
)

66 changes: 45 additions & 21 deletions examples/models/llama/eval_llama_lib.py
@@ -164,6 +164,7 @@ def _model_call(self, inps):
def gen_eval_wrapper(
model_name: str,
args: argparse.ArgumentParser,
llm_config=None,
):
"""
Generates a wrapper interface around the provided model and tokenizer for
@@ -172,7 +173,15 @@ def gen_eval_wrapper(
Returns:
eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
"""
tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore
# If llm_config is not provided, convert args to llm_config
if llm_config is None:
from executorch.examples.models.llama.config.llm_config_utils import (
convert_args_to_llm_config,
)

llm_config = convert_args_to_llm_config(args)

tokenizer = get_tokenizer(llm_config.base.tokenizer_path)

# ExecuTorch Binary Evaluation
if (model := args.pte) is not None: # pyre-ignore
@@ -182,7 +191,7 @@
model=model,
tokenizer=tokenizer,
tokenizer_bin=tokenizer_bin,
max_seq_length=args.max_seq_length, # pyre-ignore
max_seq_length=llm_config.export.max_seq_length,
)

# ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -191,12 +200,14 @@
tokenizer=tokenizer,
# Exported model takes at most (max_seq_length - 1) tokens.
# Note that the eager model takes at most max_seq_length tokens.
max_seq_length=args.max_seq_length - 1,
max_seq_length=llm_config.export.max_seq_length - 1,
)

pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
llm_config
)
# GPTFastEvalWrapper: Create a wrapper around a pre-exported model
manager: LLMEdgeManager = _prepare_for_llama_export(args)
manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)

if len(quantizers) != 0:
manager = manager.export().pt2e_quantize(quantizers)
@@ -208,9 +219,9 @@ def gen_eval_wrapper(
return GraphModuleEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
use_kv_cache=args.use_kv_cache, # pyre-ignore
enable_dynamic_shape=args.enable_dynamic_shape, # pyre-ignore
max_seq_length=llm_config.export.max_seq_length,
use_kv_cache=llm_config.model.use_kv_cache,
enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
)
else:
# TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -234,8 +245,8 @@ def gen_eval_wrapper(
return EagerEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
use_kv_cache=args.use_kv_cache,
max_seq_length=llm_config.export.max_seq_length,
use_kv_cache=llm_config.model.use_kv_cache,
)


@@ -296,12 +307,18 @@ def eval_llama(
model_name: str,
args: argparse.ArgumentParser,
) -> None:
# Convert args to LlmConfig
from executorch.examples.models.llama.config.llm_config_utils import (
convert_args_to_llm_config,
)

llm_config = convert_args_to_llm_config(args)

# Generate the eval wrapper
eval_wrapper = gen_eval_wrapper(model_name, args)
eval_wrapper = gen_eval_wrapper(model_name, args, llm_config)

# Needed for loading mmlu dataset.
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
# pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
if args.tasks and "mmlu" in args.tasks:
import datasets

@@ -312,8 +329,8 @@
eval_results = simple_evaluate(
model=eval_wrapper,
tasks=args.tasks,
num_fewshot=args.num_fewshot, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
limit=args.limit, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
num_fewshot=args.num_fewshot,
limit=args.limit,
)

for task, res in eval_results["results"].items():
@@ -326,19 +343,26 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse

This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
"""
assert args.use_attention_sink is not None # pyre-ignore [16]
assert args.attention_sink_eval_tokens > 0 # pyre-ignore [16]
attention_sink_params = args.use_attention_sink.split(",")
# Convert args to LlmConfig
from executorch.examples.models.llama.config.llm_config_utils import (
convert_args_to_llm_config,
)

llm_config = convert_args_to_llm_config(args)

assert llm_config.model.use_attention_sink is not None
assert args.attention_sink_eval_tokens > 0
attention_sink_params = llm_config.model.use_attention_sink.split(",")
assert len(attention_sink_params) == 3
sink_size = int(attention_sink_params[0])
window_size = int(attention_sink_params[1])

assert args.max_seq_length == sink_size + window_size # pyre-ignore [16]
assert llm_config.export.max_seq_length == sink_size + window_size

device = "cuda" if torch.cuda.is_available() else "cpu"
manager: LLMEdgeManager = _prepare_for_llama_export(args)
manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
model = manager.model.eval().to(device=device)
tokenizer = get_tokenizer(args.tokenizer_path) # pyre-ignore [16]
tokenizer = get_tokenizer(llm_config.base.tokenizer_path)

eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

@@ -347,7 +371,7 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
progress_bar = tqdm(total=args.attention_sink_eval_tokens)
input_pos = 0
while input_pos < args.attention_sink_eval_tokens:
for text in eval_data["text"]: # pyre-ignore [16]
for text in eval_data["text"]:
tokens = tokenizer.encode(text, bos=False, eos=False)
if len(tokens) <= 0:
continue
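
Throughout this file the old flat args.* lookups become nested attribute accesses on the config object. A rough skeleton of just the fields referenced above, reconstructed for illustration only; the authoritative definitions live in examples/models/llama/config/llm_config.py and contain many more options, and the defaults shown here are placeholders:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class BaseConfig:
    tokenizer_path: Optional[str] = None

@dataclass
class ModelConfig:
    use_kv_cache: bool = False
    enable_dynamic_shape: bool = False
    # Comma-separated string with three values (sink size, window size, ...),
    # or None when attention sink is disabled.
    use_attention_sink: Optional[str] = None

@dataclass
class ExportConfig:
    max_seq_length: int = 128

@dataclass
class LlmConfig:
    base: BaseConfig = field(default_factory=BaseConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    export: ExportConfig = field(default_factory=ExportConfig)
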
3 changes: 2 additions & 1 deletion examples/models/llama/export_llama_hydra.py
@@ -13,14 +13,15 @@
from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.export_llama_lib import export_llama
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf

cs = ConfigStore.instance()
cs.store(name="llm_config", node=LlmConfig)


@hydra.main(version_base=None, config_name="llm_config")
def main(llm_config: LlmConfig) -> None:
export_llama(llm_config)
export_llama(OmegaConf.to_object(llm_config))


if __name__ == "__main__":
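
The OmegaConf.to_object call matters because Hydra hands main a DictConfig view of the structured config, not an actual LlmConfig instance; converting it back to the dataclass lets export_llama use plain attribute access and isinstance checks. A standalone sketch of the same round trip, assuming only the LlmConfig dataclass and the export.max_seq_length field used elsewhere in this PR:

from executorch.examples.models.llama.config.llm_config import LlmConfig
from omegaconf import OmegaConf

# Build a DictConfig backed by the LlmConfig schema, apply an override, then
# materialize it back into a real dataclass instance.
schema = OmegaConf.structured(LlmConfig)
merged = OmegaConf.merge(schema, OmegaConf.from_dotlist(["export.max_seq_length=256"]))
llm_config = OmegaConf.to_object(merged)

assert isinstance(llm_config, LlmConfig)
assert llm_config.export.max_seq_length == 256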