Skip to content

Commit 13d720f

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1d5b4cf commit 13d720f

File tree

1 file changed

+27
-20
lines changed

1 file changed

+27
-20
lines changed

tests/test_batch_size_invariance.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"""
77

88
import subprocess
9-
from pathlib import Path
109

1110
import pytest
1211
import torch
@@ -29,8 +28,13 @@ def test_gradient_scale_invariance(tmp_path, batch_size_a, batch_size_b):
2928
gradient scales invariant to batch size.
3029
"""
3130
# Create two simple datasets
32-
texts_a = [f"The quick brown fox jumps over the lazy dog {i}" for i in range(batch_size_a)]
33-
texts_b = [f"A journey of a thousand miles begins with a single step {i}" for i in range(batch_size_b)]
31+
texts_a = [
32+
f"The quick brown fox jumps over the lazy dog {i}" for i in range(batch_size_a)
33+
]
34+
texts_b = [
35+
f"A journey of a thousand miles begins with a single step {i}"
36+
for i in range(batch_size_b)
37+
]
3438

3539
ds_a = Dataset.from_dict({"text": texts_a})
3640
ds_b = Dataset.from_dict({"text": texts_b})
@@ -50,12 +54,19 @@ def test_gradient_scale_invariance(tmp_path, batch_size_a, batch_size_b):
5054
def run_bergson_build(index_name: str, dataset_path: str):
5155
index_path = index_dir / index_name
5256
cmd = [
53-
"bergson", "build", str(index_path),
54-
"--model", "gpt2", # Use small model for testing
55-
"--dataset", dataset_path,
56-
"--prompt_column", "text",
57-
"--projection_dim", "8", # Small for speed
58-
"--token_batch_size", "1000",
57+
"bergson",
58+
"build",
59+
str(index_path),
60+
"--model",
61+
"gpt2", # Use small model for testing
62+
"--dataset",
63+
dataset_path,
64+
"--prompt_column",
65+
"text",
66+
"--projection_dim",
67+
"8", # Small for speed
68+
"--token_batch_size",
69+
"1000",
5970
]
6071
subprocess.run(cmd, check=True, capture_output=True)
6172
return index_path
@@ -66,12 +77,8 @@ def run_bergson_build(index_name: str, dataset_path: str):
6677
index_combined = run_bergson_build("combined", str(data_dir / "data_combined"))
6778

6879
# Load gradients
69-
grads_a = torch.from_numpy(
70-
load_gradients(index_a, structured=False).copy()
71-
).float()
72-
grads_b = torch.from_numpy(
73-
load_gradients(index_b, structured=False).copy()
74-
).float()
80+
grads_a = torch.from_numpy(load_gradients(index_a, structured=False).copy()).float()
81+
grads_b = torch.from_numpy(load_gradients(index_b, structured=False).copy()).float()
7582
grads_combined = torch.from_numpy(
7683
load_gradients(index_combined, structured=False).copy()
7784
).float()
@@ -88,8 +95,8 @@ def run_bergson_build(index_name: str, dataset_path: str):
8895

8996
# With the fix (sum instead of mean), the standard deviations should be very close
9097
# We allow 20% tolerance to account for numerical noise and outliers
91-
ratio_a = std_a_sep / std_a_comb if std_a_comb > 0 else float('inf')
92-
ratio_b = std_b_sep / std_b_comb if std_b_comb > 0 else float('inf')
98+
ratio_a = std_a_sep / std_a_comb if std_a_comb > 0 else float("inf")
99+
ratio_b = std_b_sep / std_b_comb if std_b_comb > 0 else float("inf")
93100

94101
# Before the fix, these ratios could be 6x or more different
95102
# After the fix, they should be close to 1.0
@@ -107,6 +114,6 @@ def run_bergson_build(index_name: str, dataset_path: str):
107114
a_comb_norm = grads_a_in_combined / grads_a_in_combined.norm(dim=1, keepdim=True)
108115
cosines = (a_norm * a_comb_norm).sum(dim=1)
109116

110-
assert cosines.mean() > 0.99, (
111-
f"Gradients should point in the same direction: cosine similarity = {cosines.mean():.4f}"
112-
)
117+
assert (
118+
cosines.mean() > 0.99
119+
), f"Gradients should point in the same direction: cosine similarity = {cosines.mean():.4f}"

0 commit comments

Comments
 (0)