Skip to content

Commit 8c170a5

Browse files
smarterLouisYRYJ
andcommitted
Add ground truth ekfac tests
This is still missing FSDP support and test_apply_ekfac.py from #68 Co-Authored-By: LouisYRYJ <louis.yousif@yahoo.de>
1 parent a7e183a commit 8c170a5

File tree

9 files changed

+1527
-0
lines changed

9 files changed

+1527
-0
lines changed

tests/ekfac_tests/compute_ekfac_ground_truth.py

Lines changed: 787 additions & 0 deletions
Large diffs are not rendered by default.

tests/ekfac_tests/conftest.py

Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
"""Pytest configuration and fixtures for EKFAC tests."""
2+
3+
import os
4+
from typing import Any, Optional
5+
6+
import pytest
7+
from compute_ekfac_ground_truth import (
8+
combine_covariances_step,
9+
combine_eigenvalue_corrections_step,
10+
compute_covariances_step,
11+
compute_eigenvalue_corrections_step,
12+
compute_eigenvectors_step,
13+
load_dataset_step,
14+
load_model_step,
15+
setup_paths_and_config,
16+
tokenize_and_allocate_step,
17+
)
18+
from test_utils import set_all_seeds
19+
20+
Precision = str # Type alias for precision strings
21+
22+
23+
def pytest_addoption(parser) -> None:
24+
"""Add custom command-line options for EKFAC tests."""
25+
parser.addoption(
26+
"--model_name",
27+
action="store",
28+
type=str,
29+
default="EleutherAI/Pythia-14m",
30+
help="Model name for ground truth generation (default: EleutherAI/Pythia-14m)",
31+
)
32+
parser.addoption(
33+
"--overwrite",
34+
action="store_true",
35+
default=False,
36+
help="Overwrite existing run directory",
37+
)
38+
parser.addoption(
39+
"--precision",
40+
action="store",
41+
type=str,
42+
default="fp32",
43+
choices=["fp32", "fp16", "bf16", "int4", "int8"],
44+
help="Model precision for ground truth generation (default: fp32)",
45+
)
46+
parser.addoption(
47+
"--test_dir",
48+
action="store",
49+
default=None,
50+
help="Directory containing test data. If not provided, generates data.",
51+
)
52+
parser.addoption(
53+
"--world_size",
54+
action="store",
55+
type=int,
56+
default=1,
57+
help="World size for distributed training (default: 1)",
58+
)
59+
60+
61+
@pytest.fixture(autouse=True)
62+
def setup_test() -> None:
63+
"""Setup logic run before each test."""
64+
set_all_seeds(seed=42)
65+
66+
67+
@pytest.fixture(scope="session")
68+
def gradient_batch_size(request) -> int:
69+
return request.config.getoption("--gradient_batch_size")
70+
71+
72+
@pytest.fixture(scope="session")
73+
def gradient_path(request) -> Optional[str]:
74+
return request.config.getoption("--gradient_path")
75+
76+
77+
@pytest.fixture(scope="session")
78+
def model_name(request) -> str:
79+
return request.config.getoption("--model_name")
80+
81+
82+
@pytest.fixture(scope="session")
83+
def overwrite(request) -> bool:
84+
return request.config.getoption("--overwrite")
85+
86+
87+
@pytest.fixture(scope="session")
88+
def precision(request) -> Precision:
89+
return request.config.getoption("--precision")
90+
91+
92+
@pytest.fixture(scope="session")
93+
def use_fsdp(request) -> bool:
94+
return request.config.getoption("--use_fsdp")
95+
96+
97+
@pytest.fixture(scope="session")
98+
def world_size(request) -> int:
99+
return request.config.getoption("--world_size")
100+
101+
102+
@pytest.fixture(scope="session")
103+
def test_dir(request, tmp_path_factory) -> str:
104+
"""Get or create test directory (does not generate ground truth data)."""
105+
# Check if test directory was provided
106+
test_dir = request.config.getoption("--test_dir")
107+
if test_dir is not None:
108+
return test_dir
109+
110+
# Create temporary directory for auto-generated test data
111+
tmp_dir = tmp_path_factory.mktemp("ekfac_test_data")
112+
return str(tmp_dir)
113+
114+
115+
def ground_truth_base_path(test_dir: str) -> str:
116+
return os.path.join(test_dir, "ground_truth")
117+
118+
119+
@pytest.fixture(scope="session")
120+
def ground_truth_setup(
121+
request, test_dir: str, precision: Precision, overwrite: bool
122+
) -> dict[str, Any]:
123+
# Setup for generation
124+
model_name = request.config.getoption("--model_name")
125+
world_size = request.config.getoption("--world_size")
126+
127+
print(f"\n{'='*60}")
128+
print("Generating ground truth test data")
129+
print(f"Model: {model_name}")
130+
print(f"Precision: {precision}")
131+
print(f"World size: {world_size}")
132+
print(f"{'='*60}\n")
133+
134+
cfg, workers, device, target_modules, dtype = setup_paths_and_config(
135+
precision=precision,
136+
test_path=ground_truth_base_path(test_dir),
137+
model_name=model_name,
138+
world_size=world_size,
139+
overwrite=overwrite,
140+
)
141+
142+
model = load_model_step(cfg, dtype)
143+
model.eval() # Disable dropout for deterministic forward passes
144+
ds = load_dataset_step(cfg)
145+
data, batches_world, tokenizer = tokenize_and_allocate_step(ds, cfg, workers)
146+
147+
return {
148+
"cfg": cfg,
149+
"workers": workers,
150+
"device": device,
151+
"target_modules": target_modules,
152+
"dtype": dtype,
153+
"model": model,
154+
"data": data,
155+
"batches_world": batches_world,
156+
}
157+
158+
159+
@pytest.fixture(scope="session")
160+
def ground_truth_covariances_path(
161+
ground_truth_setup: dict[str, Any], test_dir: str, overwrite: bool
162+
) -> str:
163+
"""Ensure ground truth covariances exist and return path."""
164+
base_path = ground_truth_base_path(test_dir)
165+
covariances_path = os.path.join(base_path, "covariances")
166+
167+
if os.path.exists(covariances_path) and not overwrite:
168+
print("Using existing covariances")
169+
return covariances_path
170+
171+
setup = ground_truth_setup
172+
# Reset seeds for deterministic computation (same seed as EKFAC will use)
173+
set_all_seeds(42)
174+
covariance_test_path = compute_covariances_step(
175+
setup["model"],
176+
setup["data"],
177+
setup["batches_world"],
178+
setup["device"],
179+
setup["target_modules"],
180+
setup["workers"],
181+
base_path,
182+
)
183+
combine_covariances_step(covariance_test_path, setup["workers"], setup["device"])
184+
print("Covariances computed")
185+
return covariances_path
186+
187+
188+
@pytest.fixture(scope="session")
189+
def ground_truth_eigenvectors_path(
190+
ground_truth_covariances_path: str,
191+
ground_truth_setup: dict[str, Any],
192+
test_dir: str,
193+
overwrite: bool,
194+
) -> str:
195+
"""Ensure ground truth eigenvectors exist and return path."""
196+
base_path = ground_truth_base_path(test_dir)
197+
eigenvectors_path = os.path.join(base_path, "eigenvectors")
198+
199+
if os.path.exists(eigenvectors_path) and not overwrite:
200+
print("Using existing eigenvectors")
201+
return eigenvectors_path
202+
203+
setup = ground_truth_setup
204+
compute_eigenvectors_step(base_path, setup["device"], setup["dtype"])
205+
print("Eigenvectors computed")
206+
return eigenvectors_path
207+
208+
209+
@pytest.fixture(scope="session")
210+
def ground_truth_eigenvalue_corrections_path(
211+
ground_truth_eigenvectors_path: str,
212+
ground_truth_setup: dict[str, Any],
213+
test_dir: str,
214+
overwrite: bool,
215+
) -> str:
216+
"""Ensure ground truth eigenvalue corrections exist and return path."""
217+
base_path = ground_truth_base_path(test_dir)
218+
eigenvalue_corrections_path = os.path.join(base_path, "eigenvalue_corrections")
219+
220+
if os.path.exists(eigenvalue_corrections_path) and not overwrite:
221+
print("Using existing eigenvalue corrections")
222+
return eigenvalue_corrections_path
223+
224+
setup = ground_truth_setup
225+
eigenvalue_correction_test_path, total_processed_global_lambda = (
226+
compute_eigenvalue_corrections_step(
227+
setup["model"],
228+
setup["data"],
229+
setup["batches_world"],
230+
setup["device"],
231+
setup["target_modules"],
232+
setup["workers"],
233+
base_path,
234+
)
235+
)
236+
combine_eigenvalue_corrections_step(
237+
eigenvalue_correction_test_path,
238+
setup["workers"],
239+
setup["device"],
240+
total_processed_global_lambda,
241+
)
242+
print("Eigenvalue corrections computed")
243+
print("\n=== Ground Truth Computation Complete ===")
244+
print(f"Results saved to: {base_path}")
245+
return eigenvalue_corrections_path
246+
247+
248+
@pytest.fixture(scope="session")
249+
def ground_truth_path(
250+
ground_truth_eigenvalue_corrections_path: str, test_dir: str
251+
) -> str:
252+
"""Get ground truth base path with all data guaranteed to exist.
253+
254+
Depends on ground_truth_eigenvalue_corrections_path to ensure all
255+
ground truth data exists.
256+
"""
257+
return ground_truth_base_path(test_dir)
258+
259+
260+
@pytest.fixture(scope="session")
261+
def ekfac_results_path(
262+
test_dir: str,
263+
ground_truth_path: str,
264+
ground_truth_setup: dict[str, Any],
265+
overwrite: bool,
266+
) -> str:
267+
"""Run EKFAC computation and return results path.
268+
269+
Uses the same data and batches as ground truth via collect_hessians to ensure
270+
identical batch composition and floating-point accumulation order.
271+
"""
272+
import torch
273+
274+
from bergson.config import HessianConfig
275+
from bergson.hessians.eigenvectors import compute_eigendecomposition
276+
from bergson.hessians.hessian_approximations import collect_hessians
277+
278+
# collect_hessians writes to partial_run_path (run_path + ".part")
279+
# We set run_path so partial_run_path points to our desired output location
280+
base_run_path = os.path.join(test_dir, "run/kfac")
281+
results_path = base_run_path + ".part" # Where collect_hessians will write
282+
283+
if os.path.exists(results_path) and not overwrite:
284+
print(f"Using existing EKFAC results in {results_path}")
285+
return results_path
286+
287+
setup = ground_truth_setup
288+
cfg = setup["cfg"]
289+
data = setup["data"]
290+
batches = setup["batches_world"][0] # Single worker
291+
target_modules = setup["target_modules"]
292+
dtype = setup["dtype"]
293+
294+
print(f"\nRunning EKFAC computation in {results_path}...")
295+
296+
# Reset seeds for determinism (same as used before GT computation)
297+
set_all_seeds(42)
298+
299+
# Reload model to get fresh state (same as GT does)
300+
model = load_model_step(cfg, dtype)
301+
model.eval()
302+
303+
cfg.run_path = base_run_path
304+
cfg.partial_run_path.mkdir(parents=True, exist_ok=True)
305+
306+
hessian_cfg = HessianConfig(
307+
method="kfac", ev_correction=True, use_dataset_labels=True
308+
)
309+
310+
# Phase 1: Covariance collection using collect_hessians
311+
collect_hessians(
312+
model=model,
313+
data=data,
314+
index_cfg=cfg,
315+
batches=batches,
316+
target_modules=target_modules,
317+
hessian_cfg=hessian_cfg,
318+
)
319+
320+
total_processed = torch.load(
321+
os.path.join(results_path, "total_processed.pt"),
322+
map_location="cpu",
323+
weights_only=False,
324+
)
325+
326+
# Phase 2: Eigendecomposition
327+
compute_eigendecomposition(
328+
os.path.join(results_path, "activation_sharded"),
329+
total_processed=total_processed,
330+
)
331+
compute_eigendecomposition(
332+
os.path.join(results_path, "gradient_sharded"),
333+
total_processed=total_processed,
334+
)
335+
336+
# Phase 3: Eigenvalue correction
337+
collect_hessians(
338+
model=model,
339+
data=data,
340+
index_cfg=cfg,
341+
batches=batches,
342+
target_modules=target_modules,
343+
hessian_cfg=hessian_cfg,
344+
ev_correction=True,
345+
)
346+
347+
print(f"EKFAC computation completed in {results_path}")
348+
return results_path

tests/ekfac_tests/ground_truth/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)