-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
254 lines (209 loc) · 8.85 KB
/
Copy pathtest.py
File metadata and controls
254 lines (209 loc) · 8.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
"""
SHREK-HRM submission test script — fully self-contained.
Runs evaluation on every (model, task) listed in config.yaml and prints an
accuracy table that matches the report (Tables II, III).
Behavior
--------
1. Reads paths and HuggingFace repo IDs from config.yaml — no hardcoded paths.
2. Downloads checkpoints from HuggingFace if the local `model/` dir is empty.
3. Downloads test data from HuggingFace if the local `data/` dir is empty.
4. For each evaluation, spawns a fresh Python subprocess (`python -c <inline>`)
in the entry's `model_code_dir`; the subprocess imports that directory's
`pretrain.py`, loads the checkpoint, runs evaluation, and prints the metrics
dict. test.py contains the full eval logic, so no external evaluate.py is
needed. The inline script calls SHREK's `pretrain.evaluate` signature
directly; model code whose `evaluate` requires extra arguments (e.g. TRM's
`evaluators` and `cpu_group`) is not handled.
5. Parses `exact_accuracy` from each run's stdout and prints a summary table.
Hardware
--------
Requires NVIDIA GPU with CUDA 12.6 (the models use flash-attn).
Total runtime: roughly 10-20 minutes on a single GH200.
Usage
-----
python test.py
"""
import glob
import os
import re
import site
import subprocess
import sys
from pathlib import Path
# Make pip-bundled CUDA libs (libnvrtc-builtins.so.*, libcudart.so.*, etc.)
# discoverable by the evaluation subprocesses. PyTorch wheels expect these on
# the dynamic loader path, but on some cluster nodes the system path doesn't
# include them. They live inside site-packages/nvidia/<pkg>/lib/ when torch was
# pip-installed with its CUDA dependencies, so we discover them and prepend
# them to LD_LIBRARY_PATH. Updating os.environ here does not change how this
# already-running interpreter resolves libraries; it matters because the eval
# subprocesses are launched with a copy of os.environ.
# This makes test.py portable across nodes (we hit a real failure on a cluster
# node that was missing libnvrtc-builtins.so.13.0 from the system path).
_site_roots = []
# site.getsitepackages() is absent from the site.py shipped by legacy
# virtualenv (< 20); probe for it instead of crashing at import time.
if hasattr(site, "getsitepackages"):
    _site_roots.extend(site.getsitepackages())
try:
    _site_roots.append(site.getusersitepackages())
except Exception:
    # Best-effort: the user site dir is optional for this lookup.
    pass
_nvidia_lib_dirs = []
for _root in _site_roots:
    # Sort for a deterministic search order (glob order is filesystem-dependent)
    # and skip duplicates when site roots overlap.
    for _lib_dir in sorted(glob.glob(os.path.join(_root, "nvidia", "*", "lib"))):
        if _lib_dir not in _nvidia_lib_dirs:
            _nvidia_lib_dirs.append(_lib_dir)
if _nvidia_lib_dirs:
    _existing = os.environ.get("LD_LIBRARY_PATH", "")
    # Drop entries we are about to prepend so repeated runs (or a shell that
    # already exports them) don't accumulate duplicates; every other existing
    # entry is kept in its original order, after ours.
    _kept = (
        [p for p in _existing.split(":") if p not in _nvidia_lib_dirs]
        if _existing
        else []
    )
    _parts = _nvidia_lib_dirs + _kept
    os.environ["LD_LIBRARY_PATH"] = ":".join(_parts)
# Environment defaults, applied with setdefault semantics so a value the user
# already exported always wins. They are applied in the order listed.
#
# HF_HUB_DISABLE_XET: HuggingFace's xet-based download client has known
#   native-library problems on some systems (missing .so files); the plain
#   HTTP download path is more reliable. It must be set before huggingface_hub
#   is imported for the variable to be honored.
# TORCHDYNAMO_DISABLE / TORCHINDUCTOR_DISABLE: keep torch.compile from kicking
#   in. The model is small, so compiling buys little at eval time, and
#   inductor's nvrtc-based kernel JIT needs a matching libnvrtc-builtins.so on
#   the host, which some cluster nodes lack (one GH200 node failed with
#   `nvrtc: error: failed to open libnvrtc-builtins.so.13.0` while a sibling
#   node worked fine). Forcing eager mode keeps test.py portable.
_ENV_DEFAULTS = {
    "HF_HUB_DISABLE_XET": "1",
    "TORCHDYNAMO_DISABLE": "1",
    "TORCHINDUCTOR_DISABLE": "1",
}
for _env_name, _env_value in _ENV_DEFAULTS.items():
    os.environ.setdefault(_env_name, _env_value)
import yaml
REPO_ROOT = Path(__file__).resolve().parent
CONFIG_PATH = REPO_ROOT / "config.yaml"
def load_config(path=None):
    """Load and return the evaluation config as a dict.

    Args:
        path: Config file to read. Defaults to ``CONFIG_PATH``
            (``config.yaml`` next to this script).

    Returns:
        The parsed top-level YAML mapping.

    Raises:
        ValueError: if the file is empty or its top level is not a mapping.
            Callers index the result by key, so fail here with a clear
            message instead of an opaque TypeError later.
    """
    config_path = CONFIG_PATH if path is None else Path(path)
    # YAML is UTF-8 by spec; don't depend on the host locale's default encoding.
    with open(config_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if not isinstance(cfg, dict):
        raise ValueError(
            f"{config_path}: expected a YAML mapping at top level, "
            f"got {type(cfg).__name__}"
        )
    return cfg
def _is_populated(local_dir: Path) -> bool:
    """Return True if `local_dir` contains at least one non-hidden entry.

    Hidden entries do not indicate downloaded content. For example, a
    `.gitkeep` placeholder keeps an otherwise empty dir in git, and recent
    huggingface_hub versions write a `.cache/huggingface/` metadata folder into
    `local_dir`, which survives an interrupted download. Counting those as
    "populated" would skip the download forever.
    """
    return any(not entry.name.startswith(".") for entry in local_dir.iterdir())
def ensure_dir_populated(local_dir: Path, hf_repo: str, repo_type: str = "model"):
    """Download `hf_repo` from HuggingFace into `local_dir` if it's empty/missing.

    Args:
        local_dir: Destination directory; created if it does not exist.
        hf_repo: HuggingFace repo id passed to ``snapshot_download``.
        repo_type: HuggingFace repo type, e.g. ``"model"`` or ``"dataset"``.

    The directory counts as populated when it has any non-hidden entry; in
    that case nothing is downloaded. If a download is needed but
    ``huggingface_hub`` is not installed, the process exits with an install
    hint.
    """
    local_dir.mkdir(parents=True, exist_ok=True)
    if _is_populated(local_dir):
        print(f"[ok] {local_dir} already populated, skipping download.")
        return
    print(f"[download] {hf_repo} -> {local_dir} (this may take a while)")
    try:
        # Imported lazily so the HF_HUB_* env defaults set at module import
        # are already in place, and so a fully-populated checkout works
        # without huggingface_hub installed.
        from huggingface_hub import snapshot_download
    except ImportError:
        sys.exit(
            "ERROR: huggingface_hub not installed.\n"
            " pip install huggingface_hub\n"
            "Then re-run test.py."
        )
    snapshot_download(
        repo_id=hf_repo,
        repo_type=repo_type,
        local_dir=str(local_dir),
    )
# Inline evaluation script, run as `python -u -c INLINE_EVAL <checkpoint_path>`
# so each call gets a fresh CUDA context and clean module state. It targets
# SHREK's pretrain.py API directly:
#   evaluate(config, train_state, eval_loader, eval_metadata, rank=, world_size=)
# Model code whose `evaluate` requires extra arguments is not supported.
# Contract with the parent process:
#   * argv[1] is the checkpoint file; `all_config.yaml` must sit next to it.
#   * `pretrain` is imported from the current working directory, so the
#     caller must set cwd to the model's code directory.
#   * The metrics dict is printed to stdout with scalar values converted to
#     plain floats, so its repr does not depend on the numpy/torch version.
INLINE_EVAL = r"""
import os, sys, yaml, torch
# Force-disable TorchDynamo / TorchInductor before importing pretrain.py.
# pretrain.py wraps the model in torch.compile() inside init_train_state(),
# which triggers nvrtc-based kernel JIT at first forward pass. On cluster
# nodes that don't ship libnvrtc-builtins.so.13.0 this fails. Disabling
# dynamo here turns torch.compile into a no-op so eval runs in pure eager
# mode regardless of the host's CUDA toolkit layout.
torch._dynamo.config.disable = True
torch._dynamo.config.suppress_errors = True
sys.path.insert(0, os.getcwd())
from pretrain import PretrainConfig, init_train_state, create_dataloader, evaluate as _evaluate
ckpt = sys.argv[1]
ckpt_dir = os.path.dirname(ckpt)
with open(os.path.join(ckpt_dir, 'all_config.yaml'), encoding='utf-8') as f:
    config = PretrainConfig(**yaml.safe_load(f))
config.eval_save_outputs = []
config.checkpoint_path = ckpt_dir
train_loader, train_metadata = create_dataloader(
    config, 'train',
    test_set_mode=False, epochs_per_iter=1,
    global_batch_size=config.global_batch_size,
    rank=0, world_size=1,
)
eval_loader, eval_metadata = create_dataloader(
    config, 'test',
    test_set_mode=True, epochs_per_iter=1,
    global_batch_size=config.global_batch_size,
    rank=0, world_size=1,
)
train_state = init_train_state(config, train_metadata, world_size=1)
# Load checkpoint, unwrap torch.compile prefix if present.
state = torch.load(ckpt, map_location='cuda')
try:
    train_state.model.load_state_dict(state, assign=True)
except Exception:
    train_state.model.load_state_dict(
        {k.removeprefix('_orig_mod.'): v for k, v in state.items()},
        assign=True,
    )
train_state.step = 0
fname = os.path.basename(ckpt)
if fname.startswith('step_'):
    train_state.step = int(fname.removeprefix('step_'))
train_state.model.eval()
print('Starting evaluation', flush=True)
metrics = _evaluate(
    config, train_state, eval_loader, eval_metadata,
    rank=0, world_size=1,
)
def _plain(value):
    # numpy>=2 prints scalars as np.float32(x) and torch as tensor(x); emit
    # bare floats so the parent process can parse the printed dict.
    if isinstance(value, dict):
        return {k: _plain(v) for k, v in value.items()}
    try:
        return float(value)
    except (TypeError, ValueError, RuntimeError):
        return value
if metrics is not None:
    print(_plain(metrics))
"""
# Matches the `exact_accuracy` entry of the printed metrics dict. The value may
# be a bare number (`0.81`, `1e-3`) or wrapped in a scalar repr such as
# numpy>=2's `np.float32(0.81)` or torch's `tensor(0.81, device='cuda:0')`.
# Only a well-formed float literal is captured, so float() below cannot fail.
_EXACT_ACCURACY_RE = re.compile(
    r"['\"]exact_accuracy['\"]\s*:\s*"
    r"(?:[A-Za-z_][\w.]*\(\s*)?"
    r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
)
def _parse_exact_accuracy(output: str):
    """Return the first `exact_accuracy` value found in `output`, or None."""
    m = _EXACT_ACCURACY_RE.search(output)
    return float(m.group(1)) if m else None
def evaluate_checkpoint(model_code_dir: Path, checkpoint_path: Path):
    """Run an inline `python -c` subprocess that evaluates the checkpoint.

    Args:
        model_code_dir: Directory containing the model's `pretrain.py`; used as
            the subprocess's working directory so the inline script imports it.
        checkpoint_path: Checkpoint file, passed to the script as argv[1].

    Returns:
        The parsed `exact_accuracy` (float in [0, 1]) or None on failure
        (missing checkpoint, or no parseable metric in the output).
    """
    if not checkpoint_path.exists():
        print(f"[skip] checkpoint not found: {checkpoint_path}")
        return None
    print(f"[eval] {checkpoint_path}")
    result = subprocess.run(
        [sys.executable, "-u", "-c", INLINE_EVAL, str(checkpoint_path)],
        cwd=str(model_code_dir),
        capture_output=True,
        text=True,
        env={**os.environ, "OMP_NUM_THREADS": "8"},
    )
    output = result.stdout + result.stderr
    accuracy = _parse_exact_accuracy(output)
    if accuracy is not None:
        return accuracy
    print(" [warn] could not parse exact_accuracy from inline eval output")
    if result.returncode != 0:
        print(f" [warn] inline eval exited with code {result.returncode}")
    print(" --- last 20 lines of output ---")
    for line in output.splitlines()[-20:]:
        print(f" {line}")
    return None
def print_results(results):
    """Print a fixed-width summary table of model, task and accuracy.

    Each entry in `results` needs 'name', 'task' and 'actual' keys. 'actual'
    is a fraction rendered as a percentage with one decimal, or None, which
    is shown as "n/a".
    """
    divider = "=" * 60
    print()
    print(divider)
    print(f"{'Model':<22} {'Task':<18} {'Accuracy':>10}")
    print("-" * 60)
    for entry in results:
        score = entry["actual"]
        shown = "n/a" if score is None else f"{score:.1%}"
        print(f"{entry['name']:<22} {entry['task']:<18} {shown:>10}")
    print(divider)
def main():
    """Download missing assets, evaluate every configured checkpoint, report.

    Reads config.yaml, makes sure the checkpoint and data directories are
    populated, evaluates each entry of ``cfg["evaluations"]`` in its own
    subprocess, prints the summary table, and exits with status 1 if any
    evaluation produced no accuracy.
    """
    import time
    cfg = load_config()
    model_dir = REPO_ROOT / cfg["model_dir"]
    data_dir = REPO_ROOT / cfg["data_dir"]
    # 1. Ensure checkpoints + data are present (download from HF on first run).
    ensure_dir_populated(model_dir, cfg["hf_checkpoints_repo"], repo_type="model")
    ensure_dir_populated(data_dir, cfg["hf_dataset_repo"], repo_type="dataset")
    # 2. Run each evaluation. Pause between subprocesses so the GPU driver
    #    fully releases the previous process's state before the next inline
    #    eval subprocess starts; back-to-back CUDA loads without a pause have
    #    caused numpy/torch import failures on this cluster.
    results = []
    for i, entry in enumerate(cfg["evaluations"]):
        if i > 0:
            time.sleep(5)
        ckpt = model_dir / entry["checkpoint_subpath"]
        actual = evaluate_checkpoint(REPO_ROOT / entry["model_code_dir"], ckpt)
        results.append({
            "name": entry["name"],
            "task": entry["task"],
            # Reference accuracy from config.yaml. It is optional, so a
            # missing key must not abort the whole run.
            "expected": entry.get("expected_accuracy"),
            "actual": actual,
        })
    # 3. Summary.
    print_results(results)
    # Exit non-zero if any eval failed (so the grader / CI can detect).
    if any(r["actual"] is None for r in results):
        sys.exit(1)
if __name__ == "__main__":
    main()