@@ -41,7 +41,7 @@ def _rand_complex(shape, seed_):
4141
4242
4343def test_random_d_r (nx ):
44- """Sample d and r uniformly and run cost (and metric when available) with those shapes."""
44+ """Sample d and r uniformly and run sgot_cost_matrix (and sgot_metric when available) with those shapes."""
4545 rng = np .random .RandomState (0 )
4646 d_min , d_max = 4 , 12
4747 r_min , r_max = 2 , 6
@@ -67,14 +67,22 @@ def test_random_d_r(nx):
6767# ---------------------------------------------------------------------
6868
6969
70- def test_delta_identity ():
70+ def test_eigenvalue_cost_matrix_simple ():
71+ Ds = np .array ([0.0 , 1.0 ])
72+ Dt = np .array ([0.0 , 2.0 ])
73+ C = eigenvalue_cost_matrix (Ds , Dt , q = 2 )
74+ expected = np .array ([[0.0 , 4.0 ], [1.0 , 1.0 ]])
75+ np .testing .assert_allclose (C , expected )
76+
77+
78+ def test_delta_matrix_1d_identity ():
7179 r = 4
7280 I = np .eye (r , dtype = complex )
7381 delta = _delta_matrix_1d (I , I , I , I )
7482 np .testing .assert_allclose (delta , np .eye (r ), atol = 1e-12 )
7583
7684
77- def test_delta_swap_invariance ():
85+ def test_delta_matrix_1d_swap_invariance ():
7886 d , r = 6 , 3
7987 _ , R , _ , _ , _ , _ = random_atoms (d = d , r = r )
8088 L = R .copy ()
@@ -98,7 +106,7 @@ def test_grassmann_zero_distance(grassman_metric, nx):
98106 np .testing .assert_allclose (dist2_np , 0.0 , atol = 1e-12 )
99107
100108
101- def test_grassmann_invalid_name ():
109+ def test_grassmann_distance_invalid_name ():
102110 delta = np .ones ((2 , 2 ))
103111 with pytest .raises (ValueError ):
104112 _grassmann_distance_squared (delta , grassman_metric = "cordal" )
@@ -110,7 +118,7 @@ def test_grassmann_invalid_name():
110118
111119
112120def test_cost_self_zero (nx ):
113- """(D_S R_S L_S D_S): diagonal of cost matrix (same atom to same atom) should be near zero."""
121+ """(D_S R_S L_S D_S): diagonal of sgot_cost_matrix matrix (same atom to same atom) should be near zero."""
114122 Ds , Rs , Ls , _ , _ , _ = random_atoms ()
115123 Ds_b , Rs_b , Ls_b , Ds_b2 , Rs_b2 , Ls_b2 = nx .from_numpy (Ds , Rs , Ls , Ds , Rs , Ls )
116124 C = sgot_cost_matrix (Ds_b , Rs_b , Ls_b , Ds_b2 , Rs_b2 , Ls_b2 )
@@ -119,7 +127,7 @@ def test_cost_self_zero(nx):
119127 np .testing .assert_allclose (C_np , C_np .T , atol = 1e-10 )
120128
121129
122- def test_cost_reference (nx ):
130+ def test_grassmann_cost_reference (nx ):
123131 """Cost with same inputs and HPs should be deterministic (np.testing.assert_allclose)."""
124132 Ds , Rs , Ls , Dt , Rt , Lt = random_atoms ()
125133 Ds_b , Rs_b , Ls_b , Dt_b , Rt_b , Lt_b = nx .from_numpy (Ds , Rs , Ls , Dt , Rt , Lt )
@@ -132,7 +140,7 @@ def test_cost_reference(nx):
132140@pytest .mark .parametrize (
133141 "grassman_metric" , ["geodesic" , "chordal" , "procrustes" , "martin" ]
134142)
135- def test_cost_basic (grassman_metric , nx ):
143+ def test_grassmann_cost_basic_properties (grassman_metric , nx ):
136144 Ds , Rs , Ls , Dt , Rt , Lt = random_atoms ()
137145 Ds_b , Rs_b , Ls_b , Dt_b , Rt_b , Lt_b = nx .from_numpy (Ds , Rs , Ls , Dt , Rt , Lt )
138146 C = sgot_cost_matrix (
@@ -144,7 +152,7 @@ def test_cost_basic(grassman_metric, nx):
144152 assert np .all (C_np >= 0 )
145153
146154
147- def test_cost_validation ():
155+ def test_sgot_cost_input_validation ():
148156 Ds , Rs , Ls , Dt , Rt , Lt = random_atoms ()
149157
150158 with pytest .raises (ValueError ):
@@ -159,21 +167,21 @@ def test_cost_validation():
159167# ---------------------------------------------------------------------
160168
161169
162- def test_metric_self_zero ():
170+ def test_sgot_metric_self_zero ():
163171 Ds , Rs , Ls , _ , _ , _ = random_atoms ()
164172 dist = sgot_metric (Ds , Rs , Ls , Ds , Rs , Ls )
165173 assert np .isfinite (dist )
166174 assert abs (dist ) < 5e-4
167175
168176
169- def test_metric_symmetry ():
177+ def test_sgot_metric_symmetry ():
170178 Ds , Rs , Ls , Dt , Rt , Lt = random_atoms ()
171179 d1 = sgot_metric (Ds , Rs , Ls , Dt , Rt , Lt )
172180 d2 = sgot_metric (Dt , Rt , Lt , Ds , Rs , Ls )
173181 np .testing .assert_allclose (d1 , d2 , atol = 1e-8 )
174182
175183
176- def test_metric_with_weights ():
184+ def test_sgot_metric_with_weights ():
177185 Ds , Rs , Ls , Dt , Rt , Lt = random_atoms ()
178186 r = Ds .shape [0 ]
179187
@@ -193,16 +201,20 @@ def test_metric_with_weights():
193201# ---------------------------------------------------------------------
194202
195203
196- def test_hyperparameter_sweep_cost (nx ):
197- """Sweep over a random set of HPs and run cost()."""
198- grassmann_types = ["geodesic" , "chordal" , "procrustes" , "martin" ]
204+ @pytest .mark .parametrize (
205+ "eta, p, q, grassman_metric" ,
206+ [
207+ (0.5 , 1 , 1 , "geodesic" ),
208+ (0.5 , 2 , 1 , "chordal" ),
209+ (0.3 , 2 , 2 , "procrustes" ),
210+ (0.7 , 1 , 2 , "martin" ),
211+ ],
212+ )
213+ def test_hyperparameter_sweep_cost (nx , eta , p , q , grassman_metric ):
214+ """Sweep over a set of fixed HPs and run cost()."""
199215 Ds , Rs , Ls , Dt , Rt , Lt = random_atoms ()
200216 Ds_b , Rs_b , Ls_b , Dt_b , Rt_b , Lt_b = nx .from_numpy (Ds , Rs , Ls , Dt , Rt , Lt )
201- rng = np .random .RandomState (2 )
202- eta = rng .uniform (0.0 , 1.0 )
203- p = rng .choice ([1 , 2 ])
204- q = rng .choice ([1 , 2 ])
205- gm = rng .choice (grassmann_types )
217+
206218 C = sgot_cost_matrix (
207219 Ds_b ,
208220 Rs_b ,
@@ -213,7 +225,7 @@ def test_hyperparameter_sweep_cost(nx):
213225 eta = eta ,
214226 p = p ,
215227 q = q ,
216- grassman_metric = gm ,
228+ grassman_metric = grassman_metric ,
217229 )
218230 C_np = nx .to_numpy (C )
219231 assert C_np .shape == (Ds .shape [0 ], Dt .shape [0 ])
0 commit comments