Commit 1954fc2

Use llm_config instead of args in export_llama functions

ghstack-source-id: e547155
Pull Request resolved: #11162

1 parent: 3031363

7 files changed: +339 -269 lines

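As orientation before the per-file hunks: the commit threads an LlmConfig object through the export and eval entry points instead of passing the raw argparse namespace around. Below is a minimal sketch of the conversion and of the nested fields the hunks read; only convert_args_to_llm_config and the field paths come from the diff, while the helper name summarize_config is made up for illustration.

    # Sketch only: `args` is whatever the existing export/eval CLI parser produces.
    from executorch.examples.models.llama.config.llm_config_utils import (
        convert_args_to_llm_config,
    )


    def summarize_config(args):
        # Convert the argparse namespace once and read the fields used in the hunks below.
        llm_config = convert_args_to_llm_config(args)
        return {
            "tokenizer_path": llm_config.base.tokenizer_path,  # was args.tokenizer_path
            "max_seq_length": llm_config.export.max_seq_length,  # was args.max_seq_length
            "use_kv_cache": llm_config.model.use_kv_cache,  # was args.use_kv_cache
            "enable_dynamic_shape": llm_config.model.enable_dynamic_shape,
            "use_attention_sink": llm_config.model.use_attention_sink,
        }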

backends/arm/test/models/test_llama.py

Lines changed: 5 additions & 1 deletion
@@ -22,6 +22,9 @@
     TosaPipelineMI,
 )

+from executorch.examples.models.llama.config.llm_config_utils import (
+    convert_args_to_llm_config,
+)
 from executorch.examples.models.llama.export_llama_lib import (
     build_args_parser,
     get_llama_model,

@@ -89,8 +92,9 @@ def prepare_model(self):
         ]
         parser = build_args_parser()
         args = parser.parse_args(args)
+        llm_config = convert_args_to_llm_config(args)

-        llama_model, llama_inputs, llama_meta = get_llama_model(args)
+        llama_model, llama_inputs, llama_meta = get_llama_model(llm_config)

         return llama_model, llama_inputs, llama_meta
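
The flow the updated Arm test exercises can be sketched as follows; the wrapper name load_llama_for_test and the cli_args parameter are illustrative, while the imported functions are the ones used in the hunks above.

    # Sketch: `cli_args` stands in for the model-specific flag list the test builds.
    from executorch.examples.models.llama.config.llm_config_utils import (
        convert_args_to_llm_config,
    )
    from executorch.examples.models.llama.export_llama_lib import (
        build_args_parser,
        get_llama_model,
    )


    def load_llama_for_test(cli_args):
        args = build_args_parser().parse_args(cli_args)
        llm_config = convert_args_to_llm_config(args)
        # Returns (model, example inputs, metadata), matching the hunk above.
        return get_llama_model(llm_config)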

examples/models/llama/eval_llama_lib.py

Lines changed: 45 additions & 21 deletions
@@ -164,6 +164,7 @@ def _model_call(self, inps):
 def gen_eval_wrapper(
     model_name: str,
     args: argparse.ArgumentParser,
+    llm_config=None,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for

@@ -172,7 +173,15 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    # If llm_config is not provided, convert args to llm_config
+    if llm_config is None:
+        from executorch.examples.models.llama.config.llm_config_utils import (
+            convert_args_to_llm_config,
+        )
+
+        llm_config = convert_args_to_llm_config(args)
+
+    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)

     # ExecuTorch Binary Evaluation
     if (model := args.pte) is not None:  # pyre-ignore

@@ -182,7 +191,7 @@ def gen_eval_wrapper(
             model=model,
             tokenizer=tokenizer,
             tokenizer_bin=tokenizer_bin,
-            max_seq_length=args.max_seq_length,  # pyre-ignore
+            max_seq_length=llm_config.export.max_seq_length,
         )

     # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings

@@ -191,12 +200,14 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=llm_config.export.max_seq_length - 1,
         )

-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
+        llm_config
+    )
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)

     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)

@@ -208,9 +219,9 @@ def gen_eval_wrapper(
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch

@@ -234,8 +245,8 @@ def gen_eval_wrapper(
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
         )

@@ -296,12 +307,18 @@ def eval_llama(
     model_name: str,
     args: argparse.ArgumentParser,
 ) -> None:
+    # Convert args to LlmConfig
+    from executorch.examples.models.llama.config.llm_config_utils import (
+        convert_args_to_llm_config,
+    )
+
+    llm_config = convert_args_to_llm_config(args)
+
     # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
+    eval_wrapper = gen_eval_wrapper(model_name, args, llm_config)

     # Needed for loading mmlu dataset.
     # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
     if args.tasks and "mmlu" in args.tasks:
         import datasets

@@ -312,8 +329,8 @@ def eval_llama(
     eval_results = simple_evaluate(
         model=eval_wrapper,
         tasks=args.tasks,
-        num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-        limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
+        num_fewshot=args.num_fewshot,
+        limit=args.limit,
     )

     for task, res in eval_results["results"].items():
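
Before the attention-sink hunks below, a short sketch of how a caller might drive the updated eval path. The helper name build_eval_wrapper is hypothetical; the call shapes match the new gen_eval_wrapper signature, which converts args itself when no llm_config is passed.

    # Sketch: `args` is the namespace produced by the eval CLI's parser (not built here).
    from executorch.examples.models.llama.config.llm_config_utils import (
        convert_args_to_llm_config,
    )
    from executorch.examples.models.llama.eval_llama_lib import gen_eval_wrapper


    def build_eval_wrapper(model_name, args, reuse_config=True):
        if reuse_config:
            # Convert once and hand the LlmConfig through explicitly.
            llm_config = convert_args_to_llm_config(args)
            return gen_eval_wrapper(model_name, args, llm_config)
        # Fallback: gen_eval_wrapper converts args internally when llm_config is None.
        return gen_eval_wrapper(model_name, args)
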
@@ -326,19 +343,26 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse

     This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
     """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
+    # Convert args to LlmConfig
+    from executorch.examples.models.llama.config.llm_config_utils import (
+        convert_args_to_llm_config,
+    )
+
+    llm_config = convert_args_to_llm_config(args)
+
+    assert llm_config.model.use_attention_sink is not None
+    assert args.attention_sink_eval_tokens > 0
+    attention_sink_params = llm_config.model.use_attention_sink.split(",")
     assert len(attention_sink_params) == 3
     sink_size = int(attention_sink_params[0])
     window_size = int(attention_sink_params[1])

-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
+    assert llm_config.export.max_seq_length == sink_size + window_size

     device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
     model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
+    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)

     eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

@@ -347,7 +371,7 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
     progress_bar = tqdm(total=args.attention_sink_eval_tokens)
     input_pos = 0
     while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
+        for text in eval_data["text"]:
             tokens = tokenizer.encode(text, bos=False, eos=False)
             if len(tokens) <= 0:
                 continue
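
One invariant worth noting from the attention-sink hunks: the use_attention_sink value is a comma-separated triple, and its first two fields (sink size and window size) must sum to the export max_seq_length. A tiny numeric sketch with made-up values:

    # Illustrative values only; the third field of the triple is ignored in this sketch.
    use_attention_sink = "4,2044,1024"  # hypothetical "sink_size,window_size,<third field>"
    sink_size, window_size, _ = (int(x) for x in use_attention_sink.split(","))
    max_seq_length = 2048  # would come from llm_config.export.max_seq_length

    assert len(use_attention_sink.split(",")) == 3
    assert max_seq_length == sink_size + window_size  # 4 + 2044 == 2048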
