@@ -164,6 +164,7 @@ def _model_call(self, inps):
 def gen_eval_wrapper(
     model_name: str,
     args: argparse.ArgumentParser,
+    llm_config=None,
 ):
     """
     Generates a wrapper interface around the provided model and tokenizer for
@@ -172,7 +173,15 @@ def gen_eval_wrapper(
     Returns:
         eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
     """
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore
+    # If llm_config is not provided, convert args to llm_config
+    if llm_config is None:
+        from executorch.examples.models.llama.config.llm_config_utils import (
+            convert_args_to_llm_config,
+        )
+
+        llm_config = convert_args_to_llm_config(args)
+
+    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)

     # ExecuTorch Binary Evaluation
     if (model := args.pte) is not None:  # pyre-ignore
@@ -182,7 +191,7 @@ def gen_eval_wrapper(
                 model=model,
                 tokenizer=tokenizer,
                 tokenizer_bin=tokenizer_bin,
-                max_seq_length=args.max_seq_length,  # pyre-ignore
+                max_seq_length=llm_config.export.max_seq_length,
             )

         # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
@@ -191,12 +200,14 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             # Exported model takes at most (max_seq_length - 1) tokens.
             # Note that the eager model takes at most max_seq_length tokens.
-            max_seq_length=args.max_seq_length - 1,
+            max_seq_length=llm_config.export.max_seq_length - 1,
         )

-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
+        llm_config
+    )
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)

     if len(quantizers) != 0:
         manager = manager.export().pt2e_quantize(quantizers)
@@ -208,9 +219,9 @@ def gen_eval_wrapper(
         return GraphModuleEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,  # pyre-ignore
-            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
         )
     else:
         # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
@@ -234,8 +245,8 @@ def gen_eval_wrapper(
         return EagerEvalWrapper(
             model=model,
             tokenizer=tokenizer,
-            max_seq_length=args.max_seq_length,
-            use_kv_cache=args.use_kv_cache,
+            max_seq_length=llm_config.export.max_seq_length,
+            use_kv_cache=llm_config.model.use_kv_cache,
         )

@@ -296,12 +307,18 @@ def eval_llama(
     model_name: str,
     args: argparse.ArgumentParser,
 ) -> None:
+    # Convert args to LlmConfig
+    from executorch.examples.models.llama.config.llm_config_utils import (
+        convert_args_to_llm_config,
+    )
+
+    llm_config = convert_args_to_llm_config(args)
+
     # Generate the eval wrapper
-    eval_wrapper = gen_eval_wrapper(model_name, args)
+    eval_wrapper = gen_eval_wrapper(model_name, args, llm_config)

     # Needed for loading mmlu dataset.
     # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
-    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
     if args.tasks and "mmlu" in args.tasks:
         import datasets

@@ -312,8 +329,8 @@ def eval_llama(
     eval_results = simple_evaluate(
         model=eval_wrapper,
         tasks=args.tasks,
-        num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
-        limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
+        num_fewshot=args.num_fewshot,
+        limit=args.limit,
     )

     for task, res in eval_results["results"].items():
@@ -326,19 +343,26 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse

     This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
     """
-    assert args.use_attention_sink is not None  # pyre-ignore [16]
-    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
-    attention_sink_params = args.use_attention_sink.split(",")
+    # Convert args to LlmConfig
+    from executorch.examples.models.llama.config.llm_config_utils import (
+        convert_args_to_llm_config,
+    )
+
+    llm_config = convert_args_to_llm_config(args)
+
+    assert llm_config.model.use_attention_sink is not None
+    assert args.attention_sink_eval_tokens > 0
+    attention_sink_params = llm_config.model.use_attention_sink.split(",")
     assert len(attention_sink_params) == 3
     sink_size = int(attention_sink_params[0])
     window_size = int(attention_sink_params[1])

-    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
+    assert llm_config.export.max_seq_length == sink_size + window_size

     device = "cuda" if torch.cuda.is_available() else "cpu"
-    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
     model = manager.model.eval().to(device=device)
-    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
+    tokenizer = get_tokenizer(llm_config.base.tokenizer_path)

     eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

@@ -347,7 +371,7 @@ def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParse
     progress_bar = tqdm(total=args.attention_sink_eval_tokens)
     input_pos = 0
     while input_pos < args.attention_sink_eval_tokens:
-        for text in eval_data["text"]:  # pyre-ignore [16]
+        for text in eval_data["text"]:
             tokens = tokenizer.encode(text, bos=False, eos=False)
             if len(tokens) <= 0:
                 continue
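
For context on the new parameter, here is a minimal, hypothetical sketch of the regrouping that convert_args_to_llm_config is expected to perform. Only the attributes this diff actually dereferences are shown (base.tokenizer_path, export.max_seq_length, model.use_kv_cache, model.enable_dynamic_shape, model.use_attention_sink); the dataclasses and helper below (_Base, _Model, _Export, _LlmConfigSketch, _convert_args_sketch) are illustrative stand-ins, not the real executorch LlmConfig API.

import argparse
from dataclasses import dataclass
from typing import Optional

@dataclass
class _Base:
    tokenizer_path: str = ""

@dataclass
class _Model:
    use_kv_cache: bool = False
    enable_dynamic_shape: bool = False
    use_attention_sink: Optional[str] = None

@dataclass
class _Export:
    max_seq_length: int = 128

@dataclass
class _LlmConfigSketch:
    base: _Base
    model: _Model
    export: _Export

def _convert_args_sketch(args: argparse.Namespace) -> _LlmConfigSketch:
    # Regroup the flat argparse namespace into the nested sections read by the
    # eval code; this mirrors what convert_args_to_llm_config is assumed to do.
    return _LlmConfigSketch(
        base=_Base(tokenizer_path=args.tokenizer_path),
        model=_Model(
            use_kv_cache=args.use_kv_cache,
            enable_dynamic_shape=args.enable_dynamic_shape,
            use_attention_sink=args.use_attention_sink,
        ),
        export=_Export(max_seq_length=args.max_seq_length),
    )

With a config shaped like this, the two call paths in the diff behave the same way: eval_llama converts args once and passes llm_config into gen_eval_wrapper explicitly, while gen_eval_wrapper falls back to doing the conversion itself when llm_config is None, so existing callers that pass only args keep working.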