66"""
77
88import subprocess
9- from pathlib import Path
109
1110import pytest
1211import torch
1514from bergson .data import load_gradients
1615
1716
18- @pytest .mark .slow
17+ @pytest .mark .skipif ( not torch . cuda . is_available (), reason = "CUDA not available" )
1918@pytest .mark .parametrize ("batch_size_a,batch_size_b" , [(100 , 100 ), (50 , 150 )])
2019def test_gradient_scale_invariance (tmp_path , batch_size_a , batch_size_b ):
2120 """
@@ -29,8 +28,13 @@ def test_gradient_scale_invariance(tmp_path, batch_size_a, batch_size_b):
2928 gradient scales invariant to batch size.
3029 """
3130 # Create two simple datasets
32- texts_a = [f"The quick brown fox jumps over the lazy dog { i } " for i in range (batch_size_a )]
33- texts_b = [f"A journey of a thousand miles begins with a single step { i } " for i in range (batch_size_b )]
31+ texts_a = [
32+ f"The quick brown fox jumps over the lazy dog { i } " for i in range (batch_size_a )
33+ ]
34+ texts_b = [
35+ f"A journey of a thousand miles begins with a single step { i } "
36+ for i in range (batch_size_b )
37+ ]
3438
3539 ds_a = Dataset .from_dict ({"text" : texts_a })
3640 ds_b = Dataset .from_dict ({"text" : texts_b })
@@ -50,63 +54,57 @@ def test_gradient_scale_invariance(tmp_path, batch_size_a, batch_size_b):
5054 def run_bergson_build (index_name : str , dataset_path : str ):
5155 index_path = index_dir / index_name
5256 cmd = [
53- "bergson" , "build" , str (index_path ),
54- "--model" , "gpt2" , # Use small model for testing
55- "--dataset" , dataset_path ,
56- "--prompt_column" , "text" ,
57- "--projection_dim" , "8" , # Small for speed
58- "--token_batch_size" , "1000" ,
57+ "bergson" ,
58+ "build" ,
59+ str (index_path ),
60+ "--model" ,
61+ "gpt2" , # Use small model for testing
62+ "--dataset" ,
63+ dataset_path ,
64+ "--prompt_column" ,
65+ "text" ,
66+ "--projection_dim" ,
67+ "8" , # Small for speed
68+ "--token_batch_size" ,
69+ "1000" ,
70+ "--nproc_per_node" ,
71+ "1" ,
5972 ]
6073 subprocess .run (cmd , check = True , capture_output = True )
6174 return index_path
6275
6376 # Build indices
64- index_a = run_bergson_build ("a" , str (data_dir / "data_a" ))
65- index_b = run_bergson_build ("b" , str (data_dir / "data_b" ))
66- index_combined = run_bergson_build ("combined" , str (data_dir / "data_combined" ))
77+ index_a_path = run_bergson_build ("a" , str (data_dir / "data_a" ))
78+ index_b_path = run_bergson_build ("b" , str (data_dir / "data_b" ))
79+ index_combined_path = run_bergson_build ("combined" , str (data_dir / "data_combined" ))
6780
6881 # Load gradients
6982 grads_a = torch .from_numpy (
70- load_gradients (index_a , structured = False ).copy ()
83+ load_gradients (index_a_path , structured = False ).copy ()
7184 ).float ()
7285 grads_b = torch .from_numpy (
73- load_gradients (index_b , structured = False ).copy ()
86+ load_gradients (index_b_path , structured = False ).copy ()
7487 ).float ()
7588 grads_combined = torch .from_numpy (
76- load_gradients (index_combined , structured = False ).copy ()
89+ load_gradients (index_combined_path , structured = False ).copy ()
7790 ).float ()
7891
7992 # Split combined to match a and b
8093 grads_a_in_combined = grads_combined [:batch_size_a ]
8194 grads_b_in_combined = grads_combined [batch_size_a :]
8295
8396 # Compute standard deviations
84- std_a_sep = grads_a .std ().item ()
85- std_a_comb = grads_a_in_combined .std ().item ()
86- std_b_sep = grads_b .std ().item ()
87- std_b_comb = grads_b_in_combined .std ().item ()
88-
89- # With the fix (sum instead of mean), the standard deviations should be very close
90- # We allow 20% tolerance to account for numerical noise and outliers
91- ratio_a = std_a_sep / std_a_comb if std_a_comb > 0 else float ('inf' )
92- ratio_b = std_b_sep / std_b_comb if std_b_comb > 0 else float ('inf' )
93-
94- # Before the fix, these ratios could be 6x or more different
95- # After the fix, they should be close to 1.0
96- assert 0.8 <= ratio_a <= 1.2 , (
97- f"Gradient scales for dataset A differ too much between separate and combined: "
98- f"ratio = { ratio_a :.2f} x (std_sep={ std_a_sep :.2e} , std_comb={ std_a_comb :.2e} )"
99- )
100- assert 0.8 <= ratio_b <= 1.2 , (
101- f"Gradient scales for dataset B differ too much between separate and combined: "
102- f"ratio = { ratio_b :.2f} x (std_sep={ std_b_sep :.2e} , std_comb={ std_b_comb :.2e} )"
103- )
97+ std_a_sep = grads_a .std ()
98+ std_a_comb = grads_a_in_combined .std ()
99+ std_b_sep = grads_b .std ()
100+ std_b_comb = grads_b_in_combined .std ()
101+
102+ torch .testing .assert_close (std_a_sep , std_a_comb )
103+ torch .testing .assert_close (std_b_sep , std_b_comb )
104104
105105 # Also check that cosine similarity is high (gradients point in the same direction)
106106 a_norm = grads_a / grads_a .norm (dim = 1 , keepdim = True )
107107 a_comb_norm = grads_a_in_combined / grads_a_in_combined .norm (dim = 1 , keepdim = True )
108108 cosines = (a_norm * a_comb_norm ).sum (dim = 1 )
109109
110- assert cosines .mean () > 0.99 , (
111- f"Gradients should point in the same direction: cosine similarity = { cosines .mean ():.4f} "
112- )
110+ torch .testing .assert_close (cosines .mean (), torch .tensor (1.0 ))
0 commit comments