|
12 | 12 | from pathlib import Path |
13 | 13 | from typing import Optional |
14 | 14 |
|
15 | | -from huggingface_hub import snapshot_download |
16 | | -from requests.exceptions import HTTPError |
17 | | - |
18 | 15 | from examples.scripts.utils import ( |
19 | 16 | get_env_case_insensitive, |
20 | 17 | get_node_rank, |
|
25 | 22 | from primus.core.launcher.parser import PrimusParser |
26 | 23 |
|
27 | 24 |
|
def hf_download(repo_id: str, local_dir: str, hf_token: Optional[str] = None) -> None:
    """Fetch lightweight repo assets (configs, tokenizer files) from the HuggingFace Hub.

    Weight files (``*.bin``, ``*.pt``, ``*.safetensors``) are excluded via
    ``ignore_patterns``; everything else from *repo_id* is materialized (no
    symlinks) under *local_dir*.

    Exits the process with a helpful message on HTTP 401 (missing/invalid
    ``HF_TOKEN``); any other HTTP error propagates to the caller.
    """
    try:
        snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            token=hf_token,
            ignore_patterns=["*.bin", "*.pt", "*.safetensors"],
        )
    except HTTPError as err:
        # Anything other than an auth failure is unexpected — re-raise it.
        if err.response.status_code != 401:
            raise
        log_error_and_exit("You need to pass a valid `HF_TOKEN` to download private checkpoints.")
42 | | - |
43 | | - |
44 | 25 | def parse_args(): |
45 | 26 | parser = argparse.ArgumentParser(description="Prepare Primus environment") |
46 | 27 | parser.add_argument("--primus_path", type=str, required=True, help="Root path to the Primus project") |
@@ -85,6 +66,75 @@ def resolve_backend_path( |
85 | 66 | return path |
86 | 67 |
|
87 | 68 |
|
def run_titan_hf_download(
    torchtitan_path: Path, repo_id: str, local_dir: Path, hf_token: Optional[str] = None
):
    """Invoke TorchTitan's bundled ``scripts/download_hf_assets.py`` to fetch tokenizer assets.

    Only ``--assets tokenizer`` is requested, so model weights are not downloaded.

    Args:
        torchtitan_path (Path): Root of the TorchTitan checkout (must contain ``scripts/``).
        repo_id (str): HuggingFace repo ID, e.g. ``meta-llama/Llama-3.1-70B``.
        local_dir (Path): Directory passed to the script's ``--local_dir`` option.
        hf_token (Optional[str]): If set, exported as ``HF_TOKEN`` to the child process.

    Exits the process (via ``log_error_and_exit``) when the script is missing
    or the download subprocess returns a non-zero exit code.
    """
    import sys  # local import: `sys` is not visibly imported at module level in this file

    script_path = torchtitan_path / "scripts" / "download_hf_assets.py"
    if not script_path.is_file():
        log_error_and_exit(f"TorchTitan script not found: {script_path}")

    # Use the interpreter running this launcher rather than whatever `python`
    # happens to resolve to on PATH, so the child sees the same environment.
    cmd = [
        sys.executable,
        str(script_path),
        "--repo_id",
        repo_id,
        "--assets",
        "tokenizer",
        "--local_dir",
        str(local_dir),
    ]
    env = os.environ.copy()
    if hf_token:
        env["HF_TOKEN"] = hf_token

    log_info(f"[rank0] Running Titan HF downloader:\n {' '.join(cmd)}")
    ret = subprocess.run(cmd, env=env, cwd=torchtitan_path)
    if ret.returncode != 0:
        log_error_and_exit(f"TorchTitan HF download failed with code {ret.returncode}")
| 95 | + |
| 96 | + |
def resolve_hf_assets_path(data_path: Path, hf_assets_value: str) -> tuple[str, Path, bool]:
    """
    Resolve HuggingFace asset source — supports both repo IDs and local paths.

    Args:
        data_path (Path):
            Base data directory (e.g., /data/primus_data).
        hf_assets_value (str):
            Can be either:
            - A HuggingFace repo ID (e.g., "meta-llama/Llama-3.1-70B")
            - A local directory path (e.g., "/data/primus_data/torchtitan/Llama-3.1-70B")

    Returns:
        (repo_or_path, local_dir, need_download)
            repo_or_path: str — repo_id if remote; same path if local
            local_dir: Path — where assets are or will be located
            need_download: bool — True if download is required

    Behavior:
        1. If hf_assets_value is an existing directory path:
           → Treat it as an already downloaded local path.
        2. If it is not an existing directory (likely a repo_id):
           → Derive the local target dir as
             data_path / "torchtitan" / <last_component_of_repo_id>
           → Mark need_download=True.
    """
    path_candidate = Path(hf_assets_value).expanduser()

    # Case 1: already-downloaded local directory.
    # Path.is_dir() returns False for nonexistent paths, so a separate
    # exists() check would be redundant.
    if path_candidate.is_dir():
        log_info(f"Detected local HF assets path: {path_candidate}")
        return hf_assets_value, path_candidate.resolve(), False

    # Case 2: repo_id (e.g., meta-llama/Llama-3.1-70B) → need to download
    repo_id = hf_assets_value
    repo_name = Path(repo_id).name  # last segment, e.g., Llama-3.1-70B
    local_dir = data_path / "torchtitan" / repo_name
    log_info(f"Resolved HF repo_id={repo_id}, local_dir={local_dir}")
    return repo_id, local_dir, True
| 136 | + |
| 137 | + |
88 | 138 | def main(): |
89 | 139 | args = parse_args() |
90 | 140 |
|
@@ -120,28 +170,35 @@ def main(): |
120 | 170 | if not hasattr(pre_trainer_cfg.model, "hf_assets_path") or not pre_trainer_cfg.model.hf_assets_path: |
121 | 171 | log_error_and_exit("Missing required field: pre_trainer.model.tokenizer_path") |
122 | 172 |
|
123 | | - hf_assets_path = pre_trainer_cfg.model.hf_assets_path |
124 | | - |
125 | | - full_path = data_path / "torchtitan" / hf_assets_path.lstrip("/") |
| 173 | + hf_assets_value = pre_trainer_cfg.model.hf_assets_path |
| 174 | + repo_id, local_dir, need_download = resolve_hf_assets_path(data_path, hf_assets_value) |
| 175 | + tokenizer_file = local_dir / "tokenizer.json" |
126 | 176 |
|
127 | | - tokenizer_test_file = full_path / "tokenizer.json" |
128 | | - if not tokenizer_test_file.is_file(): |
| 177 | + if need_download: |
| 178 | + # Remote repo_id case — download via Titan script |
129 | 179 | hf_token = os.environ.get("HF_TOKEN") |
130 | 180 | if not hf_token: |
131 | | - log_error_and_exit("HF_TOKEN not set. Please export HF_TOKEN.") |
| 181 | + log_error_and_exit("HF_TOKEN not set. Please export HF_TOKEN before running prepare.") |
132 | 182 |
|
133 | 183 | if get_node_rank() == 0: |
134 | | - log_info(f"Downloading HF assets for tokenizer to {full_path} ...") |
135 | | - full_path.mkdir(parents=True, exist_ok=True) |
136 | | - hf_download(repo_id=hf_assets_path, local_dir=str(full_path), hf_token=hf_token) |
| 184 | + if not tokenizer_file.exists(): |
| 185 | + log_info(f"Downloading HF assets from repo={repo_id} into {local_dir} ...") |
| 186 | + parent_dir = local_dir.parent |
| 187 | + parent_dir.mkdir(parents=True, exist_ok=True) |
| 188 | + run_titan_hf_download(torchtitan_path, repo_id, parent_dir, hf_token) |
| 189 | + else: |
| 190 | + log_info(f"Tokenizer assets already exist: {tokenizer_file}") |
137 | 191 | else: |
138 | | - log_info(f"Rank {get_node_rank()} waiting for tokenizer download ...") |
139 | | - while not tokenizer_test_file.exists(): |
| 192 | + # Other ranks wait until the file is available |
| 193 | + log_info(f"[rank{get_node_rank()}] waiting for tokenizer download ...") |
| 194 | + while not tokenizer_file.exists(): |
140 | 195 | time.sleep(5) |
141 | 196 | else: |
142 | | - log_info(f"Tokenizer assets already exist: {tokenizer_test_file}") |
| 197 | + # Local path case — skip download |
| 198 | + log_info(f"HF assets already available locally at {local_dir}") |
143 | 199 |
|
144 | | - write_patch_args(patch_args_file, "train_args", {"model.hf_assets_path": str(full_path)}) |
| 200 | + # Pass resolved path to training phase |
| 201 | + write_patch_args(patch_args_file, "train_args", {"model.hf_assets_path": str(local_dir)}) |
145 | 202 | write_patch_args(patch_args_file, "train_args", {"backend_path": str(torchtitan_path)}) |
146 | 203 | write_patch_args(patch_args_file, "torchrun_args", {"local-ranks-filter": "1"}) |
147 | 204 |
|
|
0 commit comments