3 changes: 3 additions & 0 deletions CLAUDE.md
@@ -0,0 +1,3 @@
Mark tests requiring GPUs with `@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")`.

Run `pre-commit run --all-files` manually if you forgot to install pre-commit and the hooks didn't run on commit.
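
A minimal sketch of the GPU-gating pattern described above, applied to a standalone test. The test name and body are illustrative only and are not part of this PR.

import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_runs_on_gpu():
    # Hypothetical example: this test only runs when a CUDA device is present.
    x = torch.ones(4, device="cuda")
    assert x.sum().item() == 4.0
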
2 changes: 1 addition & 1 deletion bergson/collector/collector.py
@@ -565,7 +565,7 @@ def fwd_bwd(model, batch):
if "advantage" in batch:
losses *= torch.tensor(batch["advantage"], device=losses.device)

losses.mean().backward()
losses.sum().backward()
model.zero_grad()

return losses
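
Why this one-line change matters: with mean(), the backward pass scales each example's gradient by 1/batch_size, so the same datapoint produces smaller per-row gradients when it is processed in a larger batch; with sum(), every example contributes its full per-sample gradient, so the scale no longer depends on how the data was batched. A self-contained sketch of that effect (illustrative only, not part of the diff):

import torch

w = torch.zeros(3, requires_grad=True)
batch = torch.randn(8, 3)

# mean(): backward distributes a 1/8 factor to every example,
# so each datapoint's gradient contribution shrinks as the batch grows.
(batch @ w).mean().backward()
grad_mean = w.grad.clone()
w.grad = None

# sum(): each example contributes its full per-sample gradient,
# independent of how many other examples share the batch.
(batch @ w).sum().backward()
grad_sum = w.grad.clone()

# The two differ exactly by the batch size.
assert torch.allclose(grad_sum, grad_mean * 8)
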
110 changes: 110 additions & 0 deletions tests/test_batch_size_invariance.py
@@ -0,0 +1,110 @@
"""
Test that gradient scales are invariant to batch size.

This test verifies the fix for issue #112 where gradients would vary in scale
depending on whether data was processed separately or together.
"""

import subprocess

import pytest
import torch
from datasets import Dataset

from bergson.data import load_gradients


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("batch_size_a,batch_size_b", [(100, 100), (50, 150)])
def test_gradient_scale_invariance(tmp_path, batch_size_a, batch_size_b):
"""
Test that gradient scales don't depend on how we batch the data.

This reproduces the bug from issue #112: when computing gradients for the same
set of datapoints, the gradient magnitudes should be consistent regardless of
whether we process them separately or together.

The fix changes loss.mean().backward() to loss.sum().backward() to make
gradient scales invariant to batch size.
"""
# Create two simple datasets
texts_a = [
f"The quick brown fox jumps over the lazy dog {i}" for i in range(batch_size_a)
]
texts_b = [
f"A journey of a thousand miles begins with a single step {i}"
for i in range(batch_size_b)
]

ds_a = Dataset.from_dict({"text": texts_a})
ds_b = Dataset.from_dict({"text": texts_b})
ds_combined = Dataset.from_dict({"text": texts_a + texts_b})

# Save datasets
data_dir = tmp_path / "data"
data_dir.mkdir()
ds_a.save_to_disk(str(data_dir / "data_a"))
ds_b.save_to_disk(str(data_dir / "data_b"))
ds_combined.save_to_disk(str(data_dir / "data_combined"))

# Build three indices with minimal settings for speed
index_dir = tmp_path / "indices"
index_dir.mkdir()

def run_bergson_build(index_name: str, dataset_path: str):
index_path = index_dir / index_name
cmd = [
"bergson",
"build",
str(index_path),
"--model",
"gpt2", # Use small model for testing
"--dataset",
dataset_path,
"--prompt_column",
"text",
"--projection_dim",
"8", # Small for speed
"--token_batch_size",
"1000",
"--nproc_per_node",
"1",
]
subprocess.run(cmd, check=True, capture_output=True)
return index_path

# Build indices
index_a_path = run_bergson_build("a", str(data_dir / "data_a"))
index_b_path = run_bergson_build("b", str(data_dir / "data_b"))
index_combined_path = run_bergson_build("combined", str(data_dir / "data_combined"))

# Load gradients
grads_a = torch.from_numpy(
load_gradients(index_a_path, structured=False).copy()
).float()
grads_b = torch.from_numpy(
load_gradients(index_b_path, structured=False).copy()
).float()
grads_combined = torch.from_numpy(
load_gradients(index_combined_path, structured=False).copy()
).float()

# Split combined to match a and b
grads_a_in_combined = grads_combined[:batch_size_a]
grads_b_in_combined = grads_combined[batch_size_a:]

# Compute standard deviations
std_a_sep = grads_a.std()
std_a_comb = grads_a_in_combined.std()
std_b_sep = grads_b.std()
std_b_comb = grads_b_in_combined.std()

torch.testing.assert_close(std_a_sep, std_a_comb)
torch.testing.assert_close(std_b_sep, std_b_comb)

# Also check that cosine similarity is high (gradients point in the same direction)
a_norm = grads_a / grads_a.norm(dim=1, keepdim=True)
a_comb_norm = grads_a_in_combined / grads_a_in_combined.norm(dim=1, keepdim=True)
cosines = (a_norm * a_comb_norm).sum(dim=1)

torch.testing.assert_close(cosines.mean(), torch.tensor(1.0))