Skip to content

Commit cdcac9f

Browse files
committedMay 22, 2025
Introduce hydra framework with backwards compatibility
ghstack-source-id: 6009f90 Pull Request resolved: #11029
1 parent 0340494 commit cdcac9f

File tree

6 files changed

+123
-11
lines changed

6 files changed

+123
-11
lines changed
 
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import argparse
9+
10+
from executorch.examples.models.llama.config.llm_config import LlmConfig
11+
12+
13+
def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
14+
"""
15+
To support legacy purposes, this function converts CLI args from
16+
argparse to an LlmConfig, which is used by the LLM export process.
17+
"""
18+
llm_config = LlmConfig()
19+
20+
# TODO: conversion code.
21+
22+
return llm_config

‎examples/models/llama/export_llama.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,50 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
# Example script for exporting Llama2 to flatbuffer
8-
9-
import logging
10-
117
# force=True to ensure logging while in debugger. Set up logger before any
128
# other imports.
9+
import logging
10+
1311
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
1412
logging.basicConfig(level=logging.INFO, format=FORMAT, force=True)
1513

14+
import argparse
15+
import runpy
1616
import sys
1717

1818
import torch
1919

20-
from .export_llama_lib import build_args_parser, export_llama
21-
2220
sys.setrecursionlimit(4096)
2321

2422

23+
def parse_hydra_arg():
24+
"""First parse out the arg for whether to use Hydra or the old CLI."""
25+
parser = argparse.ArgumentParser(add_help=True)
26+
parser.add_argument("--hydra", action="store_true")
27+
args, remaining = parser.parse_known_args()
28+
return args.hydra, remaining
29+
30+
2531
def main() -> None:
2632
seed = 42
2733
torch.manual_seed(seed)
28-
parser = build_args_parser()
29-
args = parser.parse_args()
30-
export_llama(args)
34+
35+
use_hydra, remaining_args = parse_hydra_arg()
36+
if use_hydra:
37+
# The import runs the main function of export_llama_hydra with the remaining args
38+
# under the Hydra framework.
39+
sys.argv = [arg for arg in sys.argv if arg != "--hydra"]
40+
print(f"running with {sys.argv}")
41+
runpy.run_module(
42+
"executorch.examples.models.llama.export_llama_hydra", run_name="__main__"
43+
)
44+
else:
45+
# Use the legacy version of the export_llama script which uses argsparse.
46+
from executorch.examples.models.llama.export_llama_args import (
47+
main as export_llama_args_main,
48+
)
49+
50+
export_llama_args_main(remaining_args)
3151

3252

3353
if __name__ == "__main__":
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Run export_llama with the legacy argparse setup.
9+
"""
10+
11+
from .export_llama_lib import build_args_parser, export_llama
12+
13+
14+
def main(args) -> None:
15+
parser = build_args_parser()
16+
args = parser.parse_args(args)
17+
export_llama(args)
18+
19+
20+
if __name__ == "__main__":
21+
main()
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Run export_llama using the new Hydra CLI.
9+
"""
10+
11+
import hydra
12+
13+
from executorch.examples.models.llama.config.llm_config import LlmConfig
14+
from executorch.examples.models.llama.export_llama_lib import export_llama
15+
from hydra.core.config_store import ConfigStore
16+
17+
cs = ConfigStore.instance()
18+
cs.store(name="llm_config", node=LlmConfig)
19+
20+
21+
@hydra.main(version_base=None, config_name="llm_config")
22+
def main(llm_config: LlmConfig) -> None:
23+
export_llama(llm_config)
24+
25+
26+
if __name__ == "__main__":
27+
main()

‎examples/models/llama/export_llama_lib.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
from executorch.devtools.backend_debug import print_delegation_info
2929

3030
from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
31+
32+
from executorch.examples.models.llama.config.llm_config_utils import (
33+
convert_args_to_llm_config,
34+
)
3135
from executorch.examples.models.llama.hf_download import (
3236
download_and_convert_hf_checkpoint,
3337
)
@@ -51,6 +55,7 @@
5155
get_vulkan_quantizer,
5256
)
5357
from executorch.util.activation_memory_profiler import generate_memory_trace
58+
from omegaconf.dictconfig import DictConfig
5459

5560
from ..model_factory import EagerModelFactory
5661
from .source_transformation.apply_spin_quant_r1_r2 import (
@@ -568,7 +573,24 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
568573
return return_val
569574

570575

571-
def export_llama(args) -> str:
576+
def export_llama(
577+
export_options: Union[argparse.Namespace, DictConfig],
578+
) -> str:
579+
if isinstance(export_options, argparse.Namespace):
580+
# Legacy CLI.
581+
args = export_options
582+
llm_config = convert_args_to_llm_config(export_options) # noqa: F841
583+
elif isinstance(export_options, DictConfig):
584+
# Hydra CLI.
585+
llm_config = export_options # noqa: F841
586+
pass
587+
else:
588+
raise ValueError(
589+
"Input to export_llama must be either of type argparse.Namespace or LlmConfig"
590+
)
591+
592+
# TODO: refactor rest of export_llama to use llm_config instead of args.
593+
572594
# If a checkpoint isn't provided for an HF OSS model, download and convert the
573595
# weights first.
574596
if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS:

‎examples/models/llama/install_requirements.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# Install tokenizers for hf .json tokenizer.
1111
# Install snakeviz for cProfile flamegraph
1212
# Install lm-eval for Model Evaluation with lm-evalution-harness.
13-
pip install huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
13+
pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
1414

1515
# Call the install helper for further setup
1616
python examples/models/llama/install_requirement_helper.py

0 commit comments

Comments
 (0)