Skip to content

Commit e994303

Browse files
committed
correct issues on test_sgot
1 parent a0e74be commit e994303

File tree

1 file changed

+32
-20
lines changed

1 file changed

+32
-20
lines changed

test/test_sgot.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _rand_complex(shape, seed_):
4141

4242

4343
def test_random_d_r(nx):
44-
"""Sample d and r uniformly and run cost (and metric when available) with those shapes."""
44+
"""Sample d and r uniformly and run sgot_cost_matrix (and sgot_metric when available) with those shapes."""
4545
rng = np.random.RandomState(0)
4646
d_min, d_max = 4, 12
4747
r_min, r_max = 2, 6
@@ -67,14 +67,22 @@ def test_random_d_r(nx):
6767
# ---------------------------------------------------------------------
6868

6969

70-
def test_delta_identity():
70+
def test_eigenvalue_cost_matrix_simple():
71+
Ds = np.array([0.0, 1.0])
72+
Dt = np.array([0.0, 2.0])
73+
C = eigenvalue_cost_matrix(Ds, Dt, q=2)
74+
expected = np.array([[0.0, 4.0], [1.0, 1.0]])
75+
np.testing.assert_allclose(C, expected)
76+
77+
78+
def test_delta_matrix_1d_identity():
7179
r = 4
7280
I = np.eye(r, dtype=complex)
7381
delta = _delta_matrix_1d(I, I, I, I)
7482
np.testing.assert_allclose(delta, np.eye(r), atol=1e-12)
7583

7684

77-
def test_delta_swap_invariance():
85+
def test_delta_matrix_1d_swap_invariance():
7886
d, r = 6, 3
7987
_, R, _, _, _, _ = random_atoms(d=d, r=r)
8088
L = R.copy()
@@ -98,7 +106,7 @@ def test_grassmann_zero_distance(grassman_metric, nx):
98106
np.testing.assert_allclose(dist2_np, 0.0, atol=1e-12)
99107

100108

101-
def test_grassmann_invalid_name():
109+
def test_grassmann_distance_invalid_name():
102110
delta = np.ones((2, 2))
103111
with pytest.raises(ValueError):
104112
_grassmann_distance_squared(delta, grassman_metric="cordal")
@@ -110,7 +118,7 @@ def test_grassmann_invalid_name():
110118

111119

112120
def test_cost_self_zero(nx):
113-
"""(D_S R_S L_S D_S): diagonal of cost matrix (same atom to same atom) should be near zero."""
121+
"""(D_S R_S L_S D_S): diagonal of sgot_cost_matrix matrix (same atom to same atom) should be near zero."""
114122
Ds, Rs, Ls, _, _, _ = random_atoms()
115123
Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2 = nx.from_numpy(Ds, Rs, Ls, Ds, Rs, Ls)
116124
C = sgot_cost_matrix(Ds_b, Rs_b, Ls_b, Ds_b2, Rs_b2, Ls_b2)
@@ -119,7 +127,7 @@ def test_cost_self_zero(nx):
119127
np.testing.assert_allclose(C_np, C_np.T, atol=1e-10)
120128

121129

122-
def test_cost_reference(nx):
130+
def test_grassmann_cost_reference(nx):
123131
"""Cost with same inputs and HPs should be deterministic (np.testing.assert_allclose)."""
124132
Ds, Rs, Ls, Dt, Rt, Lt = random_atoms()
125133
Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt)
@@ -132,7 +140,7 @@ def test_cost_reference(nx):
132140
@pytest.mark.parametrize(
133141
"grassman_metric", ["geodesic", "chordal", "procrustes", "martin"]
134142
)
135-
def test_cost_basic(grassman_metric, nx):
143+
def test_grassmann_cost_basic_properties(grassman_metric, nx):
136144
Ds, Rs, Ls, Dt, Rt, Lt = random_atoms()
137145
Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt)
138146
C = sgot_cost_matrix(
@@ -144,7 +152,7 @@ def test_cost_basic(grassman_metric, nx):
144152
assert np.all(C_np >= 0)
145153

146154

147-
def test_cost_validation():
155+
def test_sgot_cost_input_validation():
148156
Ds, Rs, Ls, Dt, Rt, Lt = random_atoms()
149157

150158
with pytest.raises(ValueError):
@@ -159,21 +167,21 @@ def test_cost_validation():
159167
# ---------------------------------------------------------------------
160168

161169

162-
def test_metric_self_zero():
170+
def test_sgot_metric_self_zero():
163171
Ds, Rs, Ls, _, _, _ = random_atoms()
164172
dist = sgot_metric(Ds, Rs, Ls, Ds, Rs, Ls)
165173
assert np.isfinite(dist)
166174
assert abs(dist) < 5e-4
167175

168176

169-
def test_metric_symmetry():
177+
def test_sgot_metric_symmetry():
170178
Ds, Rs, Ls, Dt, Rt, Lt = random_atoms()
171179
d1 = sgot_metric(Ds, Rs, Ls, Dt, Rt, Lt)
172180
d2 = sgot_metric(Dt, Rt, Lt, Ds, Rs, Ls)
173181
np.testing.assert_allclose(d1, d2, atol=1e-8)
174182

175183

176-
def test_metric_with_weights():
184+
def test_sgot_metric_with_weights():
177185
Ds, Rs, Ls, Dt, Rt, Lt = random_atoms()
178186
r = Ds.shape[0]
179187

@@ -193,16 +201,20 @@ def test_metric_with_weights():
193201
# ---------------------------------------------------------------------
194202

195203

196-
def test_hyperparameter_sweep_cost(nx):
197-
"""Sweep over a random set of HPs and run cost()."""
198-
grassmann_types = ["geodesic", "chordal", "procrustes", "martin"]
204+
@pytest.mark.parametrize(
205+
"eta, p, q, grassman_metric",
206+
[
207+
(0.5, 1, 1, "geodesic"),
208+
(0.5, 2, 1, "chordal"),
209+
(0.3, 2, 2, "procrustes"),
210+
(0.7, 1, 2, "martin"),
211+
],
212+
)
213+
def test_hyperparameter_sweep_cost(nx, eta, p, q, grassman_metric):
214+
"""Sweep over a set of fixed HPs and run cost()."""
199215
Ds, Rs, Ls, Dt, Rt, Lt = random_atoms()
200216
Ds_b, Rs_b, Ls_b, Dt_b, Rt_b, Lt_b = nx.from_numpy(Ds, Rs, Ls, Dt, Rt, Lt)
201-
rng = np.random.RandomState(2)
202-
eta = rng.uniform(0.0, 1.0)
203-
p = rng.choice([1, 2])
204-
q = rng.choice([1, 2])
205-
gm = rng.choice(grassmann_types)
217+
206218
C = sgot_cost_matrix(
207219
Ds_b,
208220
Rs_b,
@@ -213,7 +225,7 @@ def test_hyperparameter_sweep_cost(nx):
213225
eta=eta,
214226
p=p,
215227
q=q,
216-
grassman_metric=gm,
228+
grassman_metric=grassman_metric,
217229
)
218230
C_np = nx.to_numpy(C)
219231
assert C_np.shape == (Ds.shape[0], Dt.shape[0])

0 commit comments

Comments
 (0)