Skip to content

Commit 95c057d

Browse files
pre-commit-ci[bot] and luciaquirke
authored and committed
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1d5b4cf commit 95c057d

File tree

1 file changed

+37
-39
lines changed

1 file changed

+37
-39
lines changed

tests/test_batch_size_invariance.py

Lines changed: 37 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
"""
77

88
import subprocess
9-
from pathlib import Path
109

1110
import pytest
1211
import torch
@@ -15,7 +14,7 @@
1514
from bergson.data import load_gradients
1615

1716

18-
@pytest.mark.slow
17+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
1918
@pytest.mark.parametrize("batch_size_a,batch_size_b", [(100, 100), (50, 150)])
2019
def test_gradient_scale_invariance(tmp_path, batch_size_a, batch_size_b):
2120
"""
@@ -29,8 +28,13 @@ def test_gradient_scale_invariance(tmp_path, batch_size_a, batch_size_b):
2928
gradient scales invariant to batch size.
3029
"""
3130
# Create two simple datasets
32-
texts_a = [f"The quick brown fox jumps over the lazy dog {i}" for i in range(batch_size_a)]
33-
texts_b = [f"A journey of a thousand miles begins with a single step {i}" for i in range(batch_size_b)]
31+
texts_a = [
32+
f"The quick brown fox jumps over the lazy dog {i}" for i in range(batch_size_a)
33+
]
34+
texts_b = [
35+
f"A journey of a thousand miles begins with a single step {i}"
36+
for i in range(batch_size_b)
37+
]
3438

3539
ds_a = Dataset.from_dict({"text": texts_a})
3640
ds_b = Dataset.from_dict({"text": texts_b})
@@ -50,63 +54,57 @@ def test_gradient_scale_invariance(tmp_path, batch_size_a, batch_size_b):
5054
def run_bergson_build(index_name: str, dataset_path: str):
5155
index_path = index_dir / index_name
5256
cmd = [
53-
"bergson", "build", str(index_path),
54-
"--model", "gpt2", # Use small model for testing
55-
"--dataset", dataset_path,
56-
"--prompt_column", "text",
57-
"--projection_dim", "8", # Small for speed
58-
"--token_batch_size", "1000",
57+
"bergson",
58+
"build",
59+
str(index_path),
60+
"--model",
61+
"gpt2", # Use small model for testing
62+
"--dataset",
63+
dataset_path,
64+
"--prompt_column",
65+
"text",
66+
"--projection_dim",
67+
"8", # Small for speed
68+
"--token_batch_size",
69+
"1000",
70+
"--nproc_per_node",
71+
"1",
5972
]
6073
subprocess.run(cmd, check=True, capture_output=True)
6174
return index_path
6275

6376
# Build indices
64-
index_a = run_bergson_build("a", str(data_dir / "data_a"))
65-
index_b = run_bergson_build("b", str(data_dir / "data_b"))
66-
index_combined = run_bergson_build("combined", str(data_dir / "data_combined"))
77+
index_a_path = run_bergson_build("a", str(data_dir / "data_a"))
78+
index_b_path = run_bergson_build("b", str(data_dir / "data_b"))
79+
index_combined_path = run_bergson_build("combined", str(data_dir / "data_combined"))
6780

6881
# Load gradients
6982
grads_a = torch.from_numpy(
70-
load_gradients(index_a, structured=False).copy()
83+
load_gradients(index_a_path, structured=False).copy()
7184
).float()
7285
grads_b = torch.from_numpy(
73-
load_gradients(index_b, structured=False).copy()
86+
load_gradients(index_b_path, structured=False).copy()
7487
).float()
7588
grads_combined = torch.from_numpy(
76-
load_gradients(index_combined, structured=False).copy()
89+
load_gradients(index_combined_path, structured=False).copy()
7790
).float()
7891

7992
# Split combined to match a and b
8093
grads_a_in_combined = grads_combined[:batch_size_a]
8194
grads_b_in_combined = grads_combined[batch_size_a:]
8295

8396
# Compute standard deviations
84-
std_a_sep = grads_a.std().item()
85-
std_a_comb = grads_a_in_combined.std().item()
86-
std_b_sep = grads_b.std().item()
87-
std_b_comb = grads_b_in_combined.std().item()
88-
89-
# With the fix (sum instead of mean), the standard deviations should be very close
90-
# We allow 20% tolerance to account for numerical noise and outliers
91-
ratio_a = std_a_sep / std_a_comb if std_a_comb > 0 else float('inf')
92-
ratio_b = std_b_sep / std_b_comb if std_b_comb > 0 else float('inf')
93-
94-
# Before the fix, these ratios could be 6x or more different
95-
# After the fix, they should be close to 1.0
96-
assert 0.8 <= ratio_a <= 1.2, (
97-
f"Gradient scales for dataset A differ too much between separate and combined: "
98-
f"ratio = {ratio_a:.2f}x (std_sep={std_a_sep:.2e}, std_comb={std_a_comb:.2e})"
99-
)
100-
assert 0.8 <= ratio_b <= 1.2, (
101-
f"Gradient scales for dataset B differ too much between separate and combined: "
102-
f"ratio = {ratio_b:.2f}x (std_sep={std_b_sep:.2e}, std_comb={std_b_comb:.2e})"
103-
)
97+
std_a_sep = grads_a.std()
98+
std_a_comb = grads_a_in_combined.std()
99+
std_b_sep = grads_b.std()
100+
std_b_comb = grads_b_in_combined.std()
101+
102+
torch.testing.assert_close(std_a_sep, std_a_comb)
103+
torch.testing.assert_close(std_b_sep, std_b_comb)
104104

105105
# Also check that cosine similarity is high (gradients point in the same direction)
106106
a_norm = grads_a / grads_a.norm(dim=1, keepdim=True)
107107
a_comb_norm = grads_a_in_combined / grads_a_in_combined.norm(dim=1, keepdim=True)
108108
cosines = (a_norm * a_comb_norm).sum(dim=1)
109109

110-
assert cosines.mean() > 0.99, (
111-
f"Gradients should point in the same direction: cosine similarity = {cosines.mean():.4f}"
112-
)
110+
torch.testing.assert_close(cosines.mean(), torch.tensor(1.0))

0 commit comments

Comments
 (0)