Skip to content

Commit d1f274d

Browse files
committed
tests lcot and lssw
1 parent 63447e8 commit d1f274d

File tree

6 files changed

+273
-6
lines changed

6 files changed

+273
-6
lines changed

examples/plot_compute_wasserstein_circle.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,23 @@ def pdf_von_Mises(theta, mu, kappa):
172172
xts[i, k] = xt / (2 * np.pi)
173173

174174
L_w2 = np.zeros((n_try, 100))
175+
L_lcot = np.zeros((n_try, 100))
175176
for i in range(n_try):
176177
L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T)
178+
L_lcot[i] = ot.linear_circular_ot(xts[i].T)
177179

178180
m_w2 = np.mean(L_w2, axis=0)
179181
std_w2 = np.std(L_w2, axis=0)
180182

183+
m_lcot = np.mean(L_lcot, axis=0)
184+
std_lcot = np.mean(L_lcot, axis=0)
185+
181186
pl.figure(1)
182-
pl.plot(kappas, m_w2)
187+
pl.plot(kappas, m_w2, label="Wasserstein")
183188
pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5)
189+
pl.plot(kappas, m_lcot, label="LCOT")
190+
pl.fill_between(kappas, m_lcot - std_lcot, m_lcot + std_lcot, alpha=0.5)
191+
pl.legend()
184192
pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$")
185193
pl.xlabel(r"$\kappa$")
186194
pl.show()

ot/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
max_sliced_wasserstein_distance,
5959
sliced_wasserstein_sphere,
6060
sliced_wasserstein_sphere_unif,
61+
linear_sliced_wasserstein_sphere,
6162
)
6263
from .gromov import (
6364
gromov_wasserstein,
@@ -106,6 +107,7 @@
106107
"sinkhorn_unbalanced2",
107108
"sliced_wasserstein_distance",
108109
"sliced_wasserstein_sphere",
110+
"linear_sliced_wasserstein_sphere",
109111
"gromov_wasserstein",
110112
"gromov_wasserstein2",
111113
"gromov_barycenters",

ot/lp/solver_1d.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,19 @@ def linear_circular_ot(u_values, v_values=None, u_weights=None, v_weights=None):
12281228
if v_values is None:
12291229
dist_u = nx.minimum(nx.abs(emb_u), 1 - nx.abs(emb_u))
12301230
return nx.mean(dist_u**2, axis=0)
1231+
else:
1232+
m = v_values.shape[0]
1233+
if len(v_values.shape) == 1:
1234+
v_values = nx.reshape(v_values, (m, 1))
1235+
1236+
if u_values.shape[1] != v_values.shape[1]:
1237+
raise ValueError(
1238+
"u and v must have the same number of batchs {} and {} respectively given".format(
1239+
u_values.shape[1], v_values.shape[1]
1240+
)
1241+
)
12311242

12321243
emb_v = linear_circular_embedding(unif_s1, v_values, v_weights)
1244+
12331245
dist_uv = nx.minimum(nx.abs(emb_u - emb_v), 1 - nx.abs(emb_u - emb_v))
12341246
return nx.mean(dist_uv**2, axis=0)

ot/sliced.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,8 @@ def sliced_wasserstein_sphere(
454454
----------
455455
.. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
456456
"""
457+
d = X_s.shape[-1]
458+
457459
if a is not None and b is not None:
458460
nx = get_backend(X_s, X_t, a, b)
459461
else:
@@ -470,11 +472,16 @@ def sliced_wasserstein_sphere(
470472
if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10 ** (-4)):
471473
raise ValueError("X_t is not on the sphere.")
472474

473-
Xps_coords, projections = projection_sphere_to_circle(
475+
if projections is None:
476+
projections = get_projections_sphere(
477+
d, n_projections, seed=seed, backend=nx, type_as=X_s
478+
)
479+
480+
Xps_coords, _ = projection_sphere_to_circle(
474481
X_s, n_projections=n_projections, projections=projections, seed=seed, backend=nx
475482
)
476483

477-
Xpt_coords, projections = projection_sphere_to_circle(
484+
Xpt_coords, _ = projection_sphere_to_circle(
478485
X_t, n_projections=n_projections, projections=projections, seed=seed, backend=nx
479486
)
480487

@@ -536,6 +543,8 @@ def sliced_wasserstein_sphere_unif(
536543
-----------
537544
.. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
538545
"""
546+
d = X_s.shape[-1]
547+
539548
if a is not None:
540549
nx = get_backend(X_s, a)
541550
else:
@@ -544,7 +553,12 @@ def sliced_wasserstein_sphere_unif(
544553
if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10 ** (-4)):
545554
raise ValueError("X_s is not on the sphere.")
546555

547-
Xps_coords, projections = projection_sphere_to_circle(
556+
if projections is None:
557+
projections = get_projections_sphere(
558+
d, n_projections, seed=seed, backend=nx, type_as=X_s
559+
)
560+
561+
Xps_coords, _ = projection_sphere_to_circle(
548562
X_s, n_projections=n_projections, projections=projections, seed=seed, backend=nx
549563
)
550564

@@ -616,6 +630,7 @@ def linear_sliced_wasserstein_sphere(
616630
----------
617631
.. [77] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data. International Conference on Learning Representations.
618632
"""
633+
d = X_s.shape[-1]
619634

620635
if a is not None and b is not None:
621636
nx = get_backend(X_s, X_t, a, b)
@@ -633,12 +648,17 @@ def linear_sliced_wasserstein_sphere(
633648
if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10 ** (-4)):
634649
raise ValueError("X_t is not on the sphere.")
635650

636-
Xps_coords, projections = projection_sphere_to_circle(
651+
if projections is None:
652+
projections = get_projections_sphere(
653+
d, n_projections, seed=seed, backend=nx, type_as=X_s
654+
)
655+
656+
Xps_coords, _ = projection_sphere_to_circle(
637657
X_s, n_projections=n_projections, projections=projections, seed=seed, backend=nx
638658
)
639659

640660
if X_t is not None:
641-
Xpt_coords, projections = projection_sphere_to_circle(
661+
Xpt_coords, _ = projection_sphere_to_circle(
642662
X_t,
643663
n_projections=n_projections,
644664
projections=projections,

test/test_1d_solver.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,3 +355,54 @@ def test_wasserstein_circle_bad_shape():
355355

356356
with pytest.raises(ValueError):
357357
_ = ot.wasserstein_circle(u, v, p=1)
358+
359+
360+
def test_linear_circular_ot_devices(nx):
361+
rng = np.random.RandomState(0)
362+
363+
n = 10
364+
x = np.linspace(0, 1, n)
365+
rho_u = np.abs(rng.randn(n))
366+
rho_u /= rho_u.sum()
367+
rho_v = np.abs(rng.randn(n))
368+
rho_v /= rho_v.sum()
369+
370+
for tp in nx.__type_list__:
371+
print(nx.dtype_device(tp))
372+
373+
xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
374+
375+
lcot = ot.linear_circular_ot(xb, xb, rho_ub, rho_vb)
376+
377+
nx.assert_same_dtype_device(xb, lcot)
378+
379+
380+
def test_linear_circular_ot_bad_shape():
381+
n = 20
382+
m = 30
383+
rng = np.random.RandomState(0)
384+
u = rng.rand(n, 2)
385+
v = rng.rand(m, 1)
386+
387+
with pytest.raises(ValueError):
388+
_ = ot.linear_circular_ot(u, v)
389+
390+
391+
def test_linear_circular_ot_same_dist():
392+
n = 20
393+
rng = np.random.RandomState(0)
394+
u = rng.rand(n)
395+
396+
lcot = ot.linear_circular_ot(u, u)
397+
np.testing.assert_almost_equal(lcot, 0.0)
398+
399+
400+
def test_linear_circular_ot_different_dist():
401+
n = 20
402+
m = 30
403+
rng = np.random.RandomState(0)
404+
u = rng.rand(n)
405+
v = rng.rand(m)
406+
407+
lcot = ot.linear_circular_ot(u, v)
408+
assert lcot > 0.0

test/test_sliced.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,30 @@ def test_projections_stiefel():
289289
np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)])
290290
)
291291

292+
rng = np.random.RandomState(0)
293+
294+
projections = ot.sliced.get_projections_sphere(3, n_projs, seed=rng)
295+
projections_T = np.transpose(projections, [0, 2, 1])
296+
297+
np.testing.assert_almost_equal(
298+
np.matmul(projections_T, projections),
299+
np.array([np.eye(2) for k in range(n_projs)]),
300+
)
301+
302+
# np.testing.assert_almost_equal(projections, P)
303+
304+
305+
def test_projections_sphere_to_circle():
306+
rng = np.random.RandomState(0)
307+
308+
n_projs = 500
309+
x = rng.randn(100, 3)
310+
x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
311+
312+
x_projs, _ = ot.sliced.projection_sphere_to_circle(x, n_projs)
313+
assert x_projs.shape == (n_projs, 100)
314+
assert np.all(x_projs >= 0) and np.all(x_projs < 1)
315+
292316

293317
def test_sliced_sphere_same_dist():
294318
n = 100
@@ -506,3 +530,153 @@ def test_sliced_sphere_unif_backend_type_devices(nx):
506530
valb = ot.sliced_wasserstein_sphere_unif(xb)
507531

508532
nx.assert_same_dtype_device(xb, valb)
533+
534+
535+
def test_linear_sliced_sphere_same_dist():
536+
n = 100
537+
rng = np.random.RandomState(0)
538+
539+
x = rng.randn(n, 3)
540+
x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
541+
u = ot.utils.unif(n)
542+
543+
res = ot.linear_sliced_wasserstein_sphere(x, x, u, u, 10, seed=rng)
544+
np.testing.assert_almost_equal(res, 0.0)
545+
546+
547+
def test_linear_sliced_sphere_same_proj():
548+
n_projections = 10
549+
n = 100
550+
rng = np.random.RandomState(0)
551+
552+
x = rng.randn(n, 3)
553+
x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
554+
555+
y = rng.randn(n, 3)
556+
y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
557+
558+
seed = 42
559+
560+
cost1, log1 = ot.linear_sliced_wasserstein_sphere(
561+
x, y, seed=seed, n_projections=n_projections, log=True
562+
)
563+
cost2, log2 = ot.linear_sliced_wasserstein_sphere(
564+
x, y, seed=seed, n_projections=n_projections, log=True
565+
)
566+
567+
assert np.allclose(log1["projections"], log2["projections"])
568+
assert np.isclose(cost1, cost2)
569+
570+
571+
def test_linear_sliced_sphere_bad_shapes():
572+
n = 100
573+
rng = np.random.RandomState(0)
574+
575+
x = rng.randn(n, 3)
576+
x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
577+
578+
y = rng.randn(n, 4)
579+
y = y / np.sqrt(np.sum(x**2, -1, keepdims=True))
580+
581+
u = ot.utils.unif(n)
582+
583+
with pytest.raises(ValueError):
584+
_ = ot.linear_sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
585+
586+
587+
def test_linear_sliced_sphere_values_on_the_sphere():
588+
n = 100
589+
rng = np.random.RandomState(0)
590+
591+
x = rng.randn(n, 3)
592+
x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
593+
594+
y = rng.randn(n, 4)
595+
596+
u = ot.utils.unif(n)
597+
598+
with pytest.raises(ValueError):
599+
_ = ot.linear_sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
600+
601+
602+
def test_linear_sliced_sphere_log():
603+
n = 100
604+
rng = np.random.RandomState(0)
605+
606+
x = rng.randn(n, 4)
607+
x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
608+
y = rng.randn(n, 4)
609+
y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
610+
u = ot.utils.unif(n)
611+
612+
res, log = ot.linear_sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng, log=True)
613+
assert len(log) == 2
614+
projections = log["projections"]
615+
projected_emds = log["projected_emds"]
616+
617+
assert projections.shape[0] == len(projected_emds) == 10
618+
for emd in projected_emds:
619+
assert emd > 0
620+
621+
622+
def test_linear_sliced_sphere_different_dists():
623+
n = 100
624+
rng = np.random.RandomState(0)
625+
626+
x = rng.randn(n, 3)
627+
x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
628+
629+
u = ot.utils.unif(n)
630+
y = rng.randn(n, 3)
631+
y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
632+
633+
res = ot.linear_sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng)
634+
assert res > 0.0
635+
636+
637+
def test_1d_linear_sliced_sphere_equals_emd():
638+
n = 100
639+
m = 120
640+
rng = np.random.RandomState(0)
641+
642+
x = rng.randn(n, 2)
643+
x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
644+
x_coords = (np.arctan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi)
645+
a = rng.uniform(0, 1, n)
646+
a /= a.sum()
647+
648+
y = rng.randn(m, 2)
649+
y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
650+
y_coords = (np.arctan2(-y[:, 1], -y[:, 0]) + np.pi) / (2 * np.pi)
651+
u = ot.utils.unif(m)
652+
653+
res = ot.linear_sliced_wasserstein_sphere(x, y, a, u, 100, seed=42)
654+
expected = ot.linear_circular_ot(x_coords.T, y_coords.T, a, u)
655+
656+
np.testing.assert_almost_equal(res**2, expected, decimal=5)
657+
658+
659+
def test_linear_sliced_sphere_backend_type_devices(nx):
660+
n = 100
661+
rng = np.random.RandomState(0)
662+
663+
x = rng.randn(n, 3)
664+
x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
665+
666+
y = rng.randn(2 * n, 3)
667+
y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
668+
669+
sw_np, log = ot.linear_sliced_wasserstein_sphere(x, y, log=True)
670+
P = log["projections"]
671+
672+
for tp in nx.__type_list__:
673+
print(nx.dtype_device(tp))
674+
675+
xb, yb = nx.from_numpy(x, y, type_as=tp)
676+
677+
valb = ot.linear_sliced_wasserstein_sphere(
678+
xb, yb, projections=nx.from_numpy(P, type_as=tp)
679+
)
680+
681+
nx.assert_same_dtype_device(xb, valb)
682+
np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb))

0 commit comments

Comments
 (0)