Skip to content

Commit 63447e8

Browse files
committed
Add LSSOT
1 parent dfaaf7b commit 63447e8

File tree

3 files changed

+222
-54
lines changed

3 files changed

+222
-54
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,5 @@ Artificial Intelligence.
391391
[75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145.
392392

393393
[76] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). [LCOT: Linear Circular Optimal Transport](https://openreview.net/forum?id=49z97Y9lMq). International Conference on Learning Representations.
394+
395+
[77] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). [Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data](https://openreview.net/forum?id=fgUFZAxywx). International Conference on Learning Representations.

ot/lp/solver_1d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,7 @@ def linear_circular_ot(u_values, v_values=None, u_weights=None, v_weights=None):
12211221
elif u_weights.ndim != u_values.ndim:
12221222
u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
12231223

1224-
unif_s1 = nx.linspace(0, 1, 101)[:-1]
1224+
unif_s1 = nx.linspace(0, 1, 101, type_as=u_values)[:-1]
12251225

12261226
emb_u = linear_circular_embedding(unif_s1, u_values, u_weights)
12271227

ot/sliced.py

Lines changed: 219 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
import numpy as np
1313
from .backend import get_backend, NumpyBackend
1414
from .utils import list_to_array, get_coordinate_circle
15-
from .lp import wasserstein_circle, semidiscrete_wasserstein2_unif_circle
15+
from .lp import (
16+
wasserstein_circle,
17+
semidiscrete_wasserstein2_unif_circle,
18+
linear_circular_ot,
19+
)
1620

1721

1822
def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None):
@@ -284,6 +288,109 @@ def max_sliced_wasserstein_distance(
284288
return res
285289

286290

291+
def get_projections_sphere(d, n_projections, seed=None, backend=None, type_as=None):
292+
r"""
293+
Generates n_projections samples from the uniform distribution on the Stiefel manifold of dimension :math:`d\times 2`: :math:`\mathbb{V}_{d,2}=\{X \in \mathbb{R}^{d\times 2}, X^TX=I_2\}`
294+
295+
Parameters
296+
----------
297+
d : int
298+
dimension of the space
299+
n_projections : int
300+
number of samples requested
301+
seed: int or RandomState, optional
302+
Seed used for numpy random number generator
303+
backend:
304+
Backend to use for random generation
305+
type_as: optional
306+
Type to use for random generation
307+
308+
Returns
309+
-------
310+
out: ndarray, shape (n_projections, d, 2)
311+
312+
Examples
313+
--------
314+
>>> n_projections = 100
315+
>>> d = 5
316+
>>> projs = get_projections_sphere(d, n_projections)
317+
>>> np.allclose(np.einsum("nij, njk -> nik", projs, projs), np.eye(2)) # doctest: +NORMALIZE_WHITESPACE
318+
True
319+
"""
320+
if backend is None:
321+
nx = NumpyBackend()
322+
else:
323+
nx = backend
324+
325+
if isinstance(seed, np.random.RandomState) and str(nx) == "numpy":
326+
Z = seed.randn(n_projections, d, 2)
327+
else:
328+
if seed is not None:
329+
nx.seed(seed)
330+
Z = nx.randn(n_projections, d, 2, type_as=type_as)
331+
332+
projections, _ = nx.qr(Z)
333+
return projections
334+
335+
336+
def projection_sphere_to_circle(
337+
x, n_projections=50, projections=None, seed=None, backend=None
338+
):
339+
r"""
340+
Projection of :math:`x\in S^{d-1}` on circles using coordinates on [0,1[.
341+
342+
To get the projection on the circle, we use the following formula:
343+
.. math::
344+
P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}
345+
346+
where :math:`U` is a random matrix sampled from the uniform distribution on the Stiefel manifold of dimension :math:`d\times 2`: :math:`\mathbb{V}_{d,2}=\{X \in \mathbb{R}^{d\times 2}, X^TX=I_2\}`
347+
and :math:`x` is a point on the sphere. Then, we apply the function get_coordinate_circle to get the coordinates on :math:`[0,1[`.
348+
349+
Parameters
350+
----------
351+
x : ndarray, shape (n_samples, dim)
352+
samples on the sphere
353+
n_projections : int, optional
354+
Number of projections used for the Monte-Carlo approximation
355+
projections: shape (n_projections, dim, 2), optional
356+
Projection matrix (n_projections and seed are not used in this case)
357+
seed: int or RandomState or None, optional
358+
Seed used for random number generator
359+
backend:
360+
Backend to use for random generation
361+
362+
Returns
363+
-------
364+
Xp_coords: ndarray, shape (n_projections, n_samples)
365+
Coordinates of the projections on the circle
366+
"""
367+
if backend is None:
368+
nx = get_backend(x)
369+
else:
370+
nx = backend
371+
372+
n, d = x.shape
373+
374+
if projections is None:
375+
projections = get_projections_sphere(
376+
d, n_projections, seed=seed, backend=nx, type_as=x
377+
)
378+
379+
# Projection on S^1
380+
# Projection on plane
381+
Xp = nx.einsum("ikj, lk -> ilj", projections, x)
382+
383+
# Projection on sphere
384+
Xp = Xp / nx.sqrt(nx.sum(Xp**2, -1, keepdims=True))
385+
386+
# Get coordinates on [0,1[
387+
Xp_coords = nx.reshape(
388+
get_coordinate_circle(nx.reshape(Xp, (-1, 2))), (n_projections, n)
389+
)
390+
391+
return Xp_coords, projections
392+
393+
287394
def sliced_wasserstein_sphere(
288395
X_s,
289396
X_t,
@@ -352,9 +459,6 @@ def sliced_wasserstein_sphere(
352459
else:
353460
nx = get_backend(X_s, X_t)
354461

355-
n, d = X_s.shape
356-
m, _ = X_t.shape
357-
358462
if X_s.shape[1] != X_t.shape[1]:
359463
raise ValueError(
360464
"X_s and X_t must have the same number of dimensions {} and {} respectively given".format(
@@ -366,34 +470,12 @@ def sliced_wasserstein_sphere(
366470
if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10 ** (-4)):
367471
raise ValueError("X_t is not on the sphere.")
368472

369-
if projections is None:
370-
# Uniforms and independent samples on the Stiefel manifold V_{d,2}
371-
if isinstance(seed, np.random.RandomState) and str(nx) == "numpy":
372-
Z = seed.randn(n_projections, d, 2)
373-
else:
374-
if seed is not None:
375-
nx.seed(seed)
376-
Z = nx.randn(n_projections, d, 2, type_as=X_s)
377-
378-
projections, _ = nx.qr(Z)
379-
else:
380-
n_projections = projections.shape[0]
381-
382-
# Projection on S^1
383-
# Projection on plane
384-
Xps = nx.einsum("ikj, lk -> ilj", projections, X_s)
385-
Xpt = nx.einsum("ikj, lk -> ilj", projections, X_t)
386-
387-
# Projection on sphere
388-
Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
389-
Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True))
390-
391-
# Get coordinates on [0,1[
392-
Xps_coords = nx.reshape(
393-
get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)
473+
Xps_coords, projections = projection_sphere_to_circle(
474+
X_s, n_projections=n_projections, projections=projections, seed=seed, backend=nx
394475
)
395-
Xpt_coords = nx.reshape(
396-
get_coordinate_circle(nx.reshape(Xpt, (-1, 2))), (n_projections, m)
476+
477+
Xpt_coords, projections = projection_sphere_to_circle(
478+
X_t, n_projections=n_projections, projections=projections, seed=seed, backend=nx
397479
)
398480

399481
projected_emd = wasserstein_circle(
@@ -406,7 +488,9 @@ def sliced_wasserstein_sphere(
406488
return res
407489

408490

409-
def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log=False):
491+
def sliced_wasserstein_sphere_unif(
492+
X_s, a=None, n_projections=50, projections=None, seed=None, log=False
493+
):
410494
r"""Compute the 2-spherical sliced wasserstein w.r.t. a uniform distribution.
411495
412496
.. math::
@@ -425,6 +509,8 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log
425509
samples weights in the source domain
426510
n_projections : int, optional
427511
Number of projections used for the Monte-Carlo approximation
512+
projections: shape (n_projections, dim, 2), optional
513+
Projection matrix (n_projections and seed are not used in this case)
428514
seed: int or RandomState or None, optional
429515
Seed used for random number generator
430516
log: bool, optional
@@ -455,36 +541,116 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log
455541
else:
456542
nx = get_backend(X_s)
457543

458-
n, d = X_s.shape
459-
460544
if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10 ** (-4)):
461545
raise ValueError("X_s is not on the sphere.")
462546

463-
# Uniforms and independent samples on the Stiefel manifold V_{d,2}
464-
if isinstance(seed, np.random.RandomState) and str(nx) == "numpy":
465-
Z = seed.randn(n_projections, d, 2)
466-
else:
467-
if seed is not None:
468-
nx.seed(seed)
469-
Z = nx.randn(n_projections, d, 2, type_as=X_s)
547+
Xps_coords, projections = projection_sphere_to_circle(
548+
X_s, n_projections=n_projections, projections=projections, seed=seed, backend=nx
549+
)
470550

471-
projections, _ = nx.qr(Z)
551+
projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a)
552+
res = nx.mean(projected_emd) ** (1 / 2)
472553

473-
# Projection on S^1
474-
# Projection on plane
475-
Xps = nx.einsum("ikj, lk -> ilj", projections, X_s)
554+
if log:
555+
return res, {"projections": projections, "projected_emds": projected_emd}
556+
return res
476557

477-
# Projection on sphere
478-
Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
479558

480-
# Get coordinates on [0,1[
481-
Xps_coords = nx.reshape(
482-
get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)
559+
def linear_sliced_wasserstein_sphere(
560+
X_s,
561+
X_t=None,
562+
a=None,
563+
b=None,
564+
n_projections=50,
565+
projections=None,
566+
seed=None,
567+
log=False,
568+
):
569+
r"""Computes the linear spherical sliced wasserstein distance from :ref:`[77] <references-lssot>`.
570+
571+
General loss returned:
572+
573+
.. math::
574+
\mathrm{LSSOT}_2(\mu, \nu) = \left(\int_{\mathbb{V}_{d,2}} \mathrm{LCOT}_2^2(P^U_\#\mu, P^U_\#\nu)\ \mathrm{d}\sigma(U)\right)^{\frac12},
575+
576+
where :math:`\mu,\nu\in\mathcal{P}(S^{d-1})` are two probability measures on the sphere, :math:`\mathrm{LCOT}_2` is the linear circular optimal transport distance,
577+
and :math:`P^U_\# \mu` stands for the pushforwards of the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}`.
578+
579+
Parameters
580+
----------
581+
X_s: ndarray, shape (n_samples_a, dim)
582+
Samples in the source domain
583+
X_t: ndarray, shape (n_samples_b, dim), optional
584+
Samples in the target domain. If None, computes the distance against the uniform distribution on the sphere.
585+
a : ndarray, shape (n_samples_a,), optional
586+
samples weights in the source domain
587+
b : ndarray, shape (n_samples_b,), optional
588+
samples weights in the target domain
589+
n_projections : int, optional
590+
Number of projections used for the Monte-Carlo approximation
591+
projections: shape (n_projections, dim, 2), optional
592+
Projection matrix (n_projections and seed are not used in this case)
593+
seed: int or RandomState or None, optional
594+
Seed used for random number generator
595+
log: bool, optional
596+
if True, linear_sliced_wasserstein_sphere returns the projections used and their associated LCOT.
597+
598+
Returns
599+
-------
600+
cost: float
601+
Linear Spherical Sliced Wasserstein Cost
602+
log: dict, optional
603+
log dictionary return only if log==True in parameters
604+
605+
Examples
606+
---------
607+
>>> n_samples_a = 20
608+
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
609+
>>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True))
610+
>>> linear_sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
611+
0.0
612+
613+
614+
.. _references-lssot:
615+
References
616+
----------
617+
.. [77] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data. International Conference on Learning Representations.
618+
"""
619+
620+
if a is not None and b is not None:
621+
nx = get_backend(X_s, X_t, a, b)
622+
else:
623+
nx = get_backend(X_s, X_t)
624+
625+
if X_s.shape[1] != X_t.shape[1]:
626+
raise ValueError(
627+
"X_s and X_t must have the same number of dimensions {} and {} respectively given".format(
628+
X_s.shape[1], X_t.shape[1]
629+
)
630+
)
631+
if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10 ** (-4)):
632+
raise ValueError("X_s is not on the sphere.")
633+
if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10 ** (-4)):
634+
raise ValueError("X_t is not on the sphere.")
635+
636+
Xps_coords, projections = projection_sphere_to_circle(
637+
X_s, n_projections=n_projections, projections=projections, seed=seed, backend=nx
483638
)
484639

485-
projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a)
486-
res = nx.mean(projected_emd) ** (1 / 2)
640+
if X_t is not None:
641+
Xpt_coords, projections = projection_sphere_to_circle(
642+
X_t,
643+
n_projections=n_projections,
644+
projections=projections,
645+
seed=seed,
646+
backend=nx,
647+
)
648+
649+
projected_lcot = linear_circular_ot(
650+
Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b
651+
)
652+
res = nx.mean(projected_lcot) ** (1 / 2)
487653

488654
if log:
489-
return res, {"projections": projections, "projected_emds": projected_emd}
655+
return res, {"projections": projections, "projected_emds": projected_lcot}
490656
return res

0 commit comments

Comments
 (0)