Skip to content

Commit c7ecf7c

Browse files
committed
Support model name as a positional parameter
1 parent 422e23d commit c7ecf7c

File tree

6 files changed

+53
-21
lines changed

6 files changed

+53
-21
lines changed

cli.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def _add_arguments_common(parser):
207207
)
208208
parser.add_argument(
209209
'--model-directory',
210-
type=str,
210+
type=Path,
211211
default='.model-artifacts',
212212
help='The directory to store downloaded model artifacts'
213213
)

download.py

+30-19
Original file line numberDiff line numberDiff line change
@@ -6,46 +6,57 @@
66
import os
77

88
from build.convert_hf_checkpoint import convert_hf_checkpoint
9+
from build.model import model_aliases
910
from pathlib import Path
1011
from typing import Optional
1112

1213
from requests.exceptions import HTTPError
1314

14-
15-
def hf_download(
16-
repo_id: Optional[str] = None,
17-
model_dir: Optional[Path] = None,
15+
def download_and_convert(
16+
model: str,
17+
models_dir: Path,
1818
hf_token: Optional[str] = None) -> None:
1919
from huggingface_hub import snapshot_download
2020

21-
if model_dir is None:
22-
model_dir = Path(".model-artifacts/{repo_id}")
21+
if model in model_aliases:
22+
model = model_aliases[model]
23+
24+
model_dir = models_dir / model
25+
os.makedirs(model_dir, exist_ok=True)
2326

27+
# Download and store the HF model artifacts.
28+
print(f"Downloading {model} from HuggingFace...")
2429
try:
2530
snapshot_download(
26-
repo_id,
31+
model,
2732
local_dir=model_dir,
2833
local_dir_use_symlinks=False,
2934
token=hf_token,
3035
ignore_patterns="*safetensors*")
3136
except HTTPError as e:
3237
if e.response.status_code == 401:
33-
print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
38+
raise RuntimeError("You need to pass a valid `--hf_token=...` to download private checkpoints.")
3439
else:
3540
raise e
3641

37-
38-
def main(args):
39-
model_dir = Path(args.model_directory) / args.model
40-
os.makedirs(model_dir, exist_ok=True)
41-
42-
# Download and store the HF model artifacts.
43-
print(f"Downloading {args.model} from HuggingFace...")
44-
hf_download(args.model, model_dir, args.hf_token)
45-
4642
# Convert the model to the torchchat format.
47-
print(f"Converting {args.model} to torchchat format...")
43+
print(f"Converting {model} to torchchat format...")
4844
convert_hf_checkpoint(
4945
model_dir=model_dir,
50-
model_name=Path(args.model),
46+
model_name=Path(model),
5147
remove_bin_files=True)
48+
49+
def is_model_downloaded(
50+
model: str,
51+
models_dir: Path) -> bool:
52+
if model in model_aliases:
53+
model = model_aliases[model]
54+
55+
model_dir = models_dir / model
56+
57+
# TODO Can we be more thorough here?
58+
return os.path.isdir(model_dir)
59+
60+
61+
def main(args):
62+
download_and_convert(args.model, args.model_directory, args.hf_token)

eval.py

+5
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from build.model import Transformer
2222
from cli import add_arguments_for_eval, arg_init
23+
from download import download_and_convert, is_model_downloaded
2324
from generate import encode_tokens, model_forward
2425

2526
from quantize import set_precision
@@ -221,6 +222,10 @@ def main(args) -> None:
221222
222223
"""
223224

225+
# If a named model was provided and not downloaded, download it.
226+
if args.model and not is_model_downloaded(args.model, args.model_directory):
227+
download_and_convert(args.model, args.model_directory, args.hf_token)
228+
224229
builder_args = BuilderArgs.from_args(args)
225230
tokenizer_args = TokenizerArgs.from_args(args)
226231
quantize = args.quantize

export.py

+5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from build.builder import _initialize_model, BuilderArgs
1313
from cli import add_arguments_for_export, arg_init, check_args
14+
from download import download_and_convert, is_model_downloaded
1415
from export_aoti import export_model as export_model_aoti
1516

1617
from quantize import set_precision
@@ -36,6 +37,10 @@ def device_sync(device):
3637

3738

3839
def main(args):
40+
# If a named model was provided and not downloaded, download it.
41+
if args.model and not is_model_downloaded(args.model, args.model_directory):
42+
download_and_convert(args.model, args.model_directory, args.hf_token)
43+
3944
builder_args = BuilderArgs.from_args(args)
4045
quantize = args.quantize
4146

generate.py

+11
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from build.model import Transformer
2727
from cli import add_arguments_for_generate, arg_init, check_args
28+
from download import download_and_convert, is_model_downloaded
2829
from quantize import set_precision
2930

3031

@@ -402,6 +403,11 @@ def _main(
402403
device_sync(device=builder_args.device)
403404
if i >= 0 and generator_args.chat_mode:
404405
prompt = input("What is your prompt? ")
406+
407+
# DEBUG DO NOT COMMIT
408+
B_INST = ""
409+
E_INST = ""
410+
405411
if chat_mode:
406412
prompt = f"{B_INST} {prompt.strip()} {E_INST}"
407413
encoded = encode_tokens(
@@ -487,6 +493,11 @@ def callback(x):
487493

488494
def main(args):
489495
is_chat = args.subcommand == "chat"
496+
497+
# If a named model was provided and not downloaded, download it.
498+
if args.model and not is_model_downloaded(args.model, args.model_directory):
499+
download_and_convert(args.model, args.model_directory, args.hf_token)
500+
490501
builder_args = BuilderArgs.from_args(args)
491502
speculative_builder_args = BuilderArgs.from_speculative_args(args)
492503
tokenizer_args = TokenizerArgs.from_args(args)

torchchat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,4 @@
6363

6464
export_main(args)
6565
else:
66-
raise RuntimeError("Must specify valid subcommands: generate, export, eval")
66+
raise RuntimeError("Must specify a valid subcommand: download, chat, generate, export, or eval.")

0 commit comments

Comments
 (0)