Skip to content

Commit 27b6917

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 562f5e5 commit 27b6917

File tree

8 files changed

+30
-21
lines changed

8 files changed

+30
-21
lines changed

tests/ekfac_tests/compute_ekfac_ground_truth.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from typing import TYPE_CHECKING, Any, Optional
2121

2222
import torch
23-
import torch.distributed as dist
2423
import torch.nn.functional as F
2524
from datasets import Dataset, DatasetDict, IterableDatasetDict, load_dataset
2625
from ground_truth.collector import (

tests/ekfac_tests/conftest.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,17 +261,18 @@ def ekfac_results_path(
261261
test_dir: str,
262262
ground_truth_path: str,
263263
ground_truth_setup: dict[str, Any],
264-
overwrite: bool
264+
overwrite: bool,
265265
) -> str:
266266
"""Run EKFAC computation and return results path.
267267
268268
Uses the same data and batches as ground truth via collect_hessians to ensure
269269
identical batch composition and floating-point accumulation order.
270270
"""
271271
import torch
272+
272273
from bergson.config import HessianConfig
273-
from bergson.hessians.hessian_approximations import collect_hessians
274274
from bergson.hessians.eigenvectors import compute_eigendecomposition
275+
from bergson.hessians.hessian_approximations import collect_hessians
275276

276277
# collect_hessians writes to partial_run_path (run_path + ".part")
277278
# We set run_path so partial_run_path points to our desired output location
@@ -301,7 +302,9 @@ def ekfac_results_path(
301302
cfg.run_path = base_run_path
302303
cfg.partial_run_path.mkdir(parents=True, exist_ok=True)
303304

304-
hessian_cfg = HessianConfig(method="kfac", ev_correction=True, use_dataset_labels=True)
305+
hessian_cfg = HessianConfig(
306+
method="kfac", ev_correction=True, use_dataset_labels=True
307+
)
305308

306309
# Phase 1: Covariance collection using collect_hessians
307310
collect_hessians(

tests/ekfac_tests/test_batch_size_invariance.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,12 @@ def compute_traces(batches: list[list[int]]) -> tuple[float, float]:
8484
computer.run_with_collector_hooks()
8585

8686
# Load covariances
87-
A = load_sharded_covariances(index_cfg.partial_run_path / "activation_sharded")
88-
G = load_sharded_covariances(index_cfg.partial_run_path / "gradient_sharded")
87+
A = load_sharded_covariances(
88+
index_cfg.partial_run_path / "activation_sharded"
89+
)
90+
G = load_sharded_covariances(
91+
index_cfg.partial_run_path / "gradient_sharded"
92+
)
8993
n = torch.load(index_cfg.partial_run_path / "total_processed.pt").item()
9094

9195
return (

tests/ekfac_tests/test_compute_ekfac.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ def test_total_processed_examples(
1313
total_processed_ground_truth_path = os.path.join(
1414
ground_truth_covariances_path, "stats.json"
1515
)
16-
total_processed_run_path = os.path.join(
17-
ekfac_results_path, "total_processed.pt"
18-
)
16+
total_processed_run_path = os.path.join(ekfac_results_path, "total_processed.pt")
1917

2018
with open(total_processed_ground_truth_path, "r") as f:
2119
ground_truth_data = json.load(f)

tests/ekfac_tests/test_covariance.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ def test_covariances(
4747
if all_match:
4848
print(f"{covariance_type} covariances match within tolerance (rtol={rtol})")
4949
else:
50-
error_msg = f"{covariance_type} covariances do not match (rtol={rtol})!\n" + "\n".join(
51-
error_details
50+
error_msg = (
51+
f"{covariance_type} covariances do not match (rtol={rtol})!\n"
52+
+ "\n".join(error_details)
5253
)
5354
assert False, error_msg
5455

tests/ekfac_tests/test_eigenvalue_correction.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ def test_eigenvalue_corrections(
2929
# load run eigenvalue corrections (sharded)
3030
lambda_run = load_sharded_covariances(lambda_run_path)
3131

32-
total_processed_run_path = os.path.join(
33-
ekfac_results_path, "total_processed.pt"
34-
)
32+
total_processed_run_path = os.path.join(ekfac_results_path, "total_processed.pt")
3533
lambda_device = lambda_run[list(lambda_run.keys())[0]].device
3634
total = torch.load(total_processed_run_path, map_location=lambda_device)
3735

@@ -78,7 +76,9 @@ def test_eigenvalue_corrections(
7876
if all_match:
7977
print(f"Eigenvalue corrections match within tolerance (rtol={rtol})")
8078
elif has_significant_errors:
81-
error_msg = f"Eigenvalue corrections do not match (rtol={rtol})!\n" + "\n".join(error_details)
79+
error_msg = f"Eigenvalue corrections do not match (rtol={rtol})!\n" + "\n".join(
80+
error_details
81+
)
8282
assert False, error_msg
8383
else:
8484
print("Eigenvalue corrections: all differences within tolerance")

tests/ekfac_tests/test_fim_accuracy.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,15 @@ def test_kfac_fim_accuracy(seq_lengths, num_batches, max_rel_error, sample, tmp_
164164
computer.forward_backward = fwd_bwd_hessian_factory(index_cfg, hessian_cfg)
165165
computer.run_with_collector_hooks()
166166

167-
A_dict_kfac = load_sharded_covariances(index_cfg.partial_run_path / "activation_sharded")
168-
G_dict_kfac = load_sharded_covariances(index_cfg.partial_run_path / "gradient_sharded")
169-
total_processed_kfac = torch.load(index_cfg.partial_run_path / "total_processed.pt").item()
167+
A_dict_kfac = load_sharded_covariances(
168+
index_cfg.partial_run_path / "activation_sharded"
169+
)
170+
G_dict_kfac = load_sharded_covariances(
171+
index_cfg.partial_run_path / "gradient_sharded"
172+
)
173+
total_processed_kfac = torch.load(
174+
index_cfg.partial_run_path / "total_processed.pt"
175+
).item()
170176

171177
assert total_processed_kfac == total_processed_exact
172178

tests/ekfac_tests/test_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
from torch import Tensor
1111

1212

13-
def add_tensor_dicts(
14-
a: dict[str, Tensor], b: dict[str, Tensor]
15-
) -> dict[str, Tensor]:
13+
def add_tensor_dicts(a: dict[str, Tensor], b: dict[str, Tensor]) -> dict[str, Tensor]:
1614
"""Add two dictionaries of tensors element-wise."""
1715
assert set(a.keys()) == set(b.keys()), "Keys must match"
1816
return {k: a[k] + b[k] for k in a}

0 commit comments

Comments
 (0)