
Commit ee0dec9

test(torchtitan): add model unit tests for TorchTitan backend (#256)

Parent: b237c12

12 files changed: +331, -43 lines

examples/torchtitan/configs/MI300X/llama3.1_405B-pretrain.yaml (3 additions, 0 deletions)

```diff
@@ -18,6 +18,9 @@ modules:
 training:
   local_batch_size: 2
 
+metrics:
+  log_freq: 1
+
 optimizer:
   lr: 8.0e-5
```
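
This and the seven MI300X config diffs that follow make the same kind of change: each example config pins `metrics.log_freq` to 1, either lowering it from 10 or adding the key where it was absent, so training metrics are logged every step. As a rough sketch of the resulting block (nesting is illustrative; the exact placement follows each file's hunk context):

```yaml
# Illustrative only: the metrics block as it lands in each config.
# In the qwen3 configs it sits under the trainer's overrides section.
metrics:
  log_freq: 1   # log every training step (was 10 where the key existed)
```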

examples/torchtitan/configs/MI300X/llama3.1_70B-BF16-pretrain.yaml (1 addition, 1 deletion)

```diff
@@ -20,7 +20,7 @@ modules:
   warmup_steps: 10
 
 metrics:
-  log_freq: 10
+  log_freq: 1
 
 training:
   local_batch_size: 4
```

examples/torchtitan/configs/MI300X/llama3.1_70B-FP8-pretrain.yaml (1 addition, 1 deletion)

```diff
@@ -20,7 +20,7 @@ modules:
   warmup_steps: 10
 
 metrics:
-  log_freq: 10
+  log_freq: 1
 
 training:
   local_batch_size: 3
```

examples/torchtitan/configs/MI300X/llama3.1_8B-BF16-pretrain.yaml (1 addition, 1 deletion)

```diff
@@ -16,7 +16,7 @@ modules:
   stderr_sink_level: INFO
 
 metrics:
-  log_freq: 10
+  log_freq: 1
   enable_wandb: false
 
 lr_scheduler:
```

examples/torchtitan/configs/MI300X/llama3.1_8B-FP8-pretrain.yaml (1 addition, 1 deletion)

```diff
@@ -20,7 +20,7 @@ modules:
   warmup_steps: 10
 
 metrics:
-  log_freq: 10
+  log_freq: 1
 
 training:
   local_batch_size: 4
```

examples/torchtitan/configs/MI300X/qwen3_0.6B-pretrain.yaml (3 additions, 0 deletions)

```diff
@@ -20,6 +20,9 @@ modules:
 lr_scheduler:
   warmup_steps: 2 # lr scheduler warm up, 20% total steps
 
+metrics:
+  log_freq: 1
+
 training:
   local_batch_size: 4
   seq_len: 4096
```

examples/torchtitan/configs/MI300X/qwen3_1.7B-pretrain.yaml (3 additions, 0 deletions)

```diff
@@ -12,6 +12,9 @@ modules:
 model: qwen3_1.7b.yaml
 overrides:
 
+  metrics:
+    log_freq: 1
+
   optimizer:
     name: "AdamW"
     lr: 8.0e-4
```

examples/torchtitan/configs/MI300X/qwen3_32B-pretrain.yaml (3 additions, 0 deletions)

```diff
@@ -12,6 +12,9 @@ modules:
 model: qwen3_32b.yaml
 overrides:
 
+  metrics:
+    log_freq: 1
+
   optimizer:
     name: "AdamW"
     lr: 3.0e-4
```

examples/torchtitan/prepare.py (89 additions, 32 deletions)

```diff
@@ -12,9 +12,6 @@
 from pathlib import Path
 from typing import Optional
 
-from huggingface_hub import snapshot_download
-from requests.exceptions import HTTPError
-
 from examples.scripts.utils import (
     get_env_case_insensitive,
     get_node_rank,
@@ -25,22 +22,6 @@
 from primus.core.launcher.parser import PrimusParser
 
 
-def hf_download(repo_id: str, local_dir: str, hf_token: Optional[str] = None) -> None:
-    try:
-        snapshot_download(
-            repo_id=repo_id,
-            local_dir=local_dir,
-            local_dir_use_symlinks=False,
-            token=hf_token,
-            ignore_patterns=["*.bin", "*.pt", "*.safetensors"],
-        )
-    except HTTPError as e:
-        if e.response.status_code == 401:
-            log_error_and_exit("You need to pass a valid `HF_TOKEN` to download private checkpoints.")
-        else:
-            raise e
-
-
 def parse_args():
     parser = argparse.ArgumentParser(description="Prepare Primus environment")
     parser.add_argument("--primus_path", type=str, required=True, help="Root path to the Primus project")
@@ -85,6 +66,75 @@ def resolve_backend_path(
     return path
 
 
+def run_titan_hf_download(
+    torchtitan_path: Path, repo_id: str, local_dir: Path, hf_token: Optional[str] = None
+):
+    """Use Titan's own download_hf_assets.py to fetch tokenizer/model assets."""
+    script_path = torchtitan_path / "scripts" / "download_hf_assets.py"
+    if not script_path.is_file():
+        log_error_and_exit(f"TorchTitan script not found: {script_path}")
+
+    cmd = [
+        "python",
+        str(script_path),
+        "--repo_id",
+        repo_id,
+        "--assets",
+        "tokenizer",
+        "--local_dir",
+        str(local_dir),
+    ]
+    env = os.environ.copy()
+    if hf_token:
+        env["HF_TOKEN"] = hf_token
+
+    log_info(f"[rank0] Running Titan HF downloader:\n {' '.join(cmd)}")
+    ret = subprocess.run(cmd, env=env, cwd=torchtitan_path)
+    if ret.returncode != 0:
+        log_error_and_exit(f"TorchTitan HF download failed with code {ret.returncode}")
+
+
+def resolve_hf_assets_path(data_path: Path, hf_assets_value: str) -> tuple[str, Path, bool]:
+    """
+    Resolve HuggingFace asset source — supports both repo IDs and local paths.
+
+    Args:
+        data_path (Path):
+            Base data directory (e.g., /data/primus_data).
+        hf_assets_value (str):
+            Can be either:
+              - A HuggingFace repo ID (e.g., "meta-llama/Llama-3.1-70B")
+              - A local directory path (e.g., "/data/primus_data/torchtitan/Llama-3.1-70B")
+
+    Returns:
+        (repo_or_path, local_dir, need_download)
+            repo_or_path: str — repo_id if remote; same path if local
+            local_dir: Path — where assets are or will be located
+            need_download: bool — True if download is required
+
+    Behavior:
+        1. If hf_assets_value is an existing directory path:
+           → Treat it as an already downloaded local path.
+        2. If it is not an existing directory (likely a repo_id):
+           → Derive the local target dir as
+             data_path / "torchtitan" / <last_component_of_repo_id>
+           → Mark need_download=True.
+    """
+    path_candidate = Path(hf_assets_value).expanduser()
+
+    # Case 1: already-downloaded local directory
+    if path_candidate.exists() and path_candidate.is_dir():
+        log_info(f"Detected local HF assets path: {path_candidate}")
+        return hf_assets_value, path_candidate.resolve(), False
+
+    # Case 2: repo_id (e.g., meta-llama/Llama-3.1-70B) → need to download
+    repo_id = hf_assets_value
+    repo_name = Path(repo_id).name  # last segment, e.g., Llama-3.1-70B
+    local_dir = data_path / "torchtitan" / repo_name
+    log_info(f"Resolved HF repo_id={repo_id}, local_dir={local_dir}")
+    return repo_id, local_dir, True
+
+
 def main():
     args = parse_args()
 
@@ -120,28 +170,35 @@ def main():
     if not hasattr(pre_trainer_cfg.model, "hf_assets_path") or not pre_trainer_cfg.model.hf_assets_path:
         log_error_and_exit("Missing required field: pre_trainer.model.tokenizer_path")
 
-    hf_assets_path = pre_trainer_cfg.model.hf_assets_path
-
-    full_path = data_path / "torchtitan" / hf_assets_path.lstrip("/")
+    hf_assets_value = pre_trainer_cfg.model.hf_assets_path
+    repo_id, local_dir, need_download = resolve_hf_assets_path(data_path, hf_assets_value)
+    tokenizer_file = local_dir / "tokenizer.json"
 
-    tokenizer_test_file = full_path / "tokenizer.json"
-    if not tokenizer_test_file.is_file():
+    if need_download:
+        # Remote repo_id case — download via Titan script
         hf_token = os.environ.get("HF_TOKEN")
         if not hf_token:
-            log_error_and_exit("HF_TOKEN not set. Please export HF_TOKEN.")
+            log_error_and_exit("HF_TOKEN not set. Please export HF_TOKEN before running prepare.")
 
         if get_node_rank() == 0:
-            log_info(f"Downloading HF assets for tokenizer to {full_path} ...")
-            full_path.mkdir(parents=True, exist_ok=True)
-            hf_download(repo_id=hf_assets_path, local_dir=str(full_path), hf_token=hf_token)
+            if not tokenizer_file.exists():
+                log_info(f"Downloading HF assets from repo={repo_id} into {local_dir} ...")
+                parent_dir = local_dir.parent
+                parent_dir.mkdir(parents=True, exist_ok=True)
+                run_titan_hf_download(torchtitan_path, repo_id, parent_dir, hf_token)
+            else:
+                log_info(f"Tokenizer assets already exist: {tokenizer_file}")
         else:
-            log_info(f"Rank {get_node_rank()} waiting for tokenizer download ...")
-            while not tokenizer_test_file.exists():
+            # Other ranks wait until the file is available
+            log_info(f"[rank{get_node_rank()}] waiting for tokenizer download ...")
+            while not tokenizer_file.exists():
                 time.sleep(5)
     else:
-        log_info(f"Tokenizer assets already exist: {tokenizer_test_file}")
+        # Local path case — skip download
+        log_info(f"HF assets already available locally at {local_dir}")
 
-    write_patch_args(patch_args_file, "train_args", {"model.hf_assets_path": str(full_path)})
+    # Pass resolved path to training phase
+    write_patch_args(patch_args_file, "train_args", {"model.hf_assets_path": str(local_dir)})
     write_patch_args(patch_args_file, "train_args", {"backend_path": str(torchtitan_path)})
     write_patch_args(patch_args_file, "torchrun_args", {"local-ranks-filter": "1"})
```
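
To make the new resolution flow concrete, here is a minimal, self-contained sketch of the two branches of `resolve_hf_assets_path`, condensed from the diff above (logging omitted; the base directory and repo ID are illustrative):

```python
from pathlib import Path

def resolve(data_path: Path, value: str) -> tuple[str, Path, bool]:
    """Condensed restatement of resolve_hf_assets_path's branching."""
    candidate = Path(value).expanduser()
    if candidate.exists() and candidate.is_dir():
        # Case 1: already-downloaded local directory -> nothing to fetch
        return value, candidate.resolve(), False
    # Case 2: treat the value as a HF repo_id and derive the target dir
    return value, data_path / "torchtitan" / Path(value).name, True

data = Path("/data/primus_data")  # hypothetical base dir from the docstring
repo, local_dir, need_download = resolve(data, "meta-llama/Llama-3.1-70B")
assert repo == "meta-llama/Llama-3.1-70B"
assert need_download
assert local_dir == data / "torchtitan" / "Llama-3.1-70B"
```

Note that `main()` hands `local_dir.parent` (not `local_dir`) to `run_titan_hf_download`, so TorchTitan's `download_hf_assets.py` is evidently expected to create the repo-named subdirectory itself; the subsequent `tokenizer_file.exists()` check relies on that layout.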

tests/run_unit_tests.py (9 additions, 0 deletions)

```diff
@@ -20,11 +20,20 @@ def get_all_unit_tests():
     cur_dir = "./tests"
     unit_tests = {}
 
+    EXCLUDE_UNIT_TESTS = [
+        "unit_tests/megatron/cco/test_tp_overlap.py",
+    ]
+
     for root, dirs, files in os.walk(cur_dir):
         for file_name in files:
             if not file_name.endswith(".py") or not file_name.startswith("test_"):
                 continue
 
+            # Construct relative path from tests/
+            rel_path = os.path.relpath(os.path.join(root, file_name), start=cur_dir)
+            if rel_path in EXCLUDE_UNIT_TESTS:
+                continue
+
             if file_name not in DISTRIBUTED_UNIT_TESTS:
                 unit_tests[os.path.join(root, file_name)] = 1
             else:
```