-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathtest_batch_size_invariance.py
More file actions
110 lines (91 loc) · 3.71 KB
/
test_batch_size_invariance.py
File metadata and controls
110 lines (91 loc) · 3.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Test that gradient scales are invariant to batch size.
This test verifies the fix for issue #112 where gradients would vary in scale
depending on whether data was processed separately or together.
"""
import subprocess
import pytest
import torch
from datasets import Dataset
from bergson.data import load_gradients
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("batch_size_a,batch_size_b", [(100, 100), (50, 150)])
def test_gradient_scale_invariance(tmp_path, batch_size_a, batch_size_b):
    """
    Test that gradient scales don't depend on how we batch the data.

    This reproduces the bug from issue #112: when computing gradients for the same
    set of datapoints, the gradient magnitudes should be consistent regardless of
    whether we process them separately or together.

    The fix changes loss.mean().backward() to loss.sum().backward() to make
    gradient scales invariant to batch size.

    Three indices are built with the ``bergson build`` CLI: one for dataset A,
    one for dataset B, and one for their concatenation (A followed by B). The
    gradients of each dataset built on its own are then compared against the
    matching rows of the combined index, both in scale (standard deviation) and
    in direction (cosine similarity).

    Args:
        tmp_path: pytest-provided temporary directory; datasets and indices are
            written beneath it.
        batch_size_a: Number of examples in dataset A.
        batch_size_b: Number of examples in dataset B.
    """
    # Create two simple datasets
    texts_a = [
        f"The quick brown fox jumps over the lazy dog {i}" for i in range(batch_size_a)
    ]
    texts_b = [
        f"A journey of a thousand miles begins with a single step {i}"
        for i in range(batch_size_b)
    ]

    ds_a = Dataset.from_dict({"text": texts_a})
    ds_b = Dataset.from_dict({"text": texts_b})
    ds_combined = Dataset.from_dict({"text": texts_a + texts_b})

    # Save datasets
    data_dir = tmp_path / "data"
    data_dir.mkdir()
    ds_a.save_to_disk(str(data_dir / "data_a"))
    ds_b.save_to_disk(str(data_dir / "data_b"))
    ds_combined.save_to_disk(str(data_dir / "data_combined"))

    # Build three indices with minimal settings for speed
    index_dir = tmp_path / "indices"
    index_dir.mkdir()

    def run_bergson_build(index_name: str, dataset_path: str):
        """Run ``bergson build`` on ``dataset_path``; return the index path."""
        index_path = index_dir / index_name
        cmd = [
            "bergson",
            "build",
            str(index_path),
            "--model",
            "gpt2",  # Use small model for testing
            "--dataset",
            dataset_path,
            "--prompt_column",
            "text",
            "--projection_dim",
            "8",  # Small for speed
            "--token_batch_size",
            "1000",
            "--nproc_per_node",
            "1",
        ]
        try:
            subprocess.run(cmd, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as err:
            # The captured output is not part of CalledProcessError's message,
            # so surface it explicitly; otherwise a failed build is undebuggable.
            pytest.fail(
                f"`bergson build` failed for index {index_name!r} "
                f"(exit code {err.returncode}).\n"
                f"stdout:\n{err.stdout}\n"
                f"stderr:\n{err.stderr}"
            )
        return index_path

    def load_flat_gradients(index_path):
        """Load an index's unstructured gradients as a float32 tensor."""
        # Copy before wrapping so the tensor owns its memory (presumably the
        # loaded array is memory-mapped / read-only -- TODO confirm).
        return torch.from_numpy(
            load_gradients(index_path, structured=False).copy()
        ).float()

    # Build indices
    index_a_path = run_bergson_build("a", str(data_dir / "data_a"))
    index_b_path = run_bergson_build("b", str(data_dir / "data_b"))
    index_combined_path = run_bergson_build("combined", str(data_dir / "data_combined"))

    # Load gradients
    grads_a = load_flat_gradients(index_a_path)
    grads_b = load_flat_gradients(index_b_path)
    grads_combined = load_flat_gradients(index_combined_path)

    # The slicing below assumes one gradient row per example, in dataset order.
    # Fail loudly here rather than comparing misaligned rows further down.
    assert grads_a.shape[0] == batch_size_a, (
        f"expected {batch_size_a} gradient rows for A, got {grads_a.shape[0]}"
    )
    assert grads_b.shape[0] == batch_size_b, (
        f"expected {batch_size_b} gradient rows for B, got {grads_b.shape[0]}"
    )
    assert grads_combined.shape[0] == batch_size_a + batch_size_b, (
        f"expected {batch_size_a + batch_size_b} gradient rows for the combined "
        f"index, got {grads_combined.shape[0]}"
    )

    # Split combined to match a and b
    grads_a_in_combined = grads_combined[:batch_size_a]
    grads_b_in_combined = grads_combined[batch_size_a:]

    comparisons = (
        (grads_a, grads_a_in_combined),
        (grads_b, grads_b_in_combined),
    )
    for grads_separate, grads_in_combined in comparisons:
        # Gradient scale must not depend on what else was processed alongside.
        torch.testing.assert_close(grads_separate.std(), grads_in_combined.std())

        # Also check that cosine similarity is high (gradients point in the
        # same direction), for both datasets rather than only the first.
        cosines = torch.nn.functional.cosine_similarity(
            grads_separate, grads_in_combined, dim=1
        )
        torch.testing.assert_close(cosines.mean(), torch.tensor(1.0))