Skip to content

Commit 9c82f38

Browse files
committed
Add doc
1 parent 88b908c commit 9c82f38

File tree

1 file changed

+63
-23
lines changed

1 file changed

+63
-23
lines changed

ot/lp/solver_1d.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -933,8 +933,8 @@ def wasserstein_circle(
933933
eps=1e-6,
934934
require_sort=True,
935935
):
936-
r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or
937-
the binary search algorithm proposed in [44] otherwise.
936+
r"""Computes the Wasserstein distance on the circle using either :ref:`[45] <references-wasserstein-circle>` for p=1 or
937+
the binary search algorithm proposed in :ref:`[44] <references-wasserstein-circle>` otherwise.
938938
Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
939939
takes the value modulo 1.
940940
If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates
@@ -996,6 +996,8 @@ def wasserstein_circle(
996996
>>> wasserstein_circle(u.T, v.T)
997997
array([0.1])
998998
999+
1000+
.. _references-wasserstein-circle:
9991001
References
10001002
----------
10011003
.. [44] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
@@ -1042,7 +1044,7 @@ def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None):
10421044
.. math::
10431045
u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi},
10441046
1045-
using e.g. ot.utils.get_coordinate_circle(x)
1047+
using e.g. ot.utils.get_coordinate_circle(x).
10461048
10471049
Parameters
10481050
----------
@@ -1098,14 +1100,32 @@ def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None):
10981100

10991101

11001102
def linear_circular_embedding(x, u_values, u_weights=None, require_sort=True):
1101-
"""
1102-
Inputs:
1103-
- x: shape (m,), points where we evaluate the embedding
1104-
- u_values: shape (n, ...) (coordinates on [0,1[)
1105-
- u_weights: shape (n, ...)
1103+
r"""Returns the embedding :math:`\hat{\mu}(x)` of Linear Circular OT with reference
1104+
:math:`\eta=\mathrm{Unif}(S^1)` evaluated in :math:`x`.
1105+
1106+
For any :math:`x\in [0,1[`, the embedding is given by (see :ref:`[76] <references-lcot>`)
1107+
1108+
.. math``
1109+
\hat{\mu}(x) = F_{\mu}^{-1}\big(x - \int z\mathrm{d}\mu(z) + \frac12) - x.
1110+
1111+
Parameters
1112+
----------
1113+
x : ndary, shape (m,)
1114+
Points in [0,1[ where to evaluate the embedding
1115+
u_values : ndarray, shape (n, ...)
1116+
samples in the source domain (coordinates on [0,1[)
1117+
u_weights : ndarray, shape (n, ...), optional
1118+
samples weights in the source domain
1119+
1120+
Returns
1121+
-------
1122+
embedding: ndarray of shape (m, ...)
1123+
Embedding evaluated at :math:`x`
11061124
1107-
Output:
1108-
- embedding of shape (m, ...)
1125+
.. _references-lcot:
1126+
References
1127+
----------
1128+
.. [76] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations.
11091129
"""
11101130
if u_weights is not None:
11111131
nx = get_backend(u_values, u_weights)
@@ -1140,27 +1160,47 @@ def linear_circular_embedding(x, u_values, u_weights=None, require_sort=True):
11401160
return (u_quantiles - x[:, None]) % 1
11411161

11421162

1143-
def linear_circular_ot(u_values, v_values=None, u_weights=None, v_weights=None, p=2):
1144-
"""
1145-
LCOT from [1]
1163+
def linear_circular_ot(u_values, v_values=None, u_weights=None, v_weights=None):
1164+
r"""Computes the Linear Circular Optimal Transport distance from :ref:`[76] <references-lcot>` using :math:`\eta=\mathrm{Unif}(S^1)`
1165+
as reference measure.
1166+
Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
1167+
takes the value modulo 1.
1168+
If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
1169+
using e.g. the atan2 function.
1170+
1171+
General loss returned:
1172+
1173+
.. math::
1174+
\mathrm{LCOT}_2^2(\mu, \nu) = \int_0^1 d_{S^1}\big(\hat{\mu}(t), \hat{\nu}(t)\big)^2\ \mathrm{d}t
1175+
1176+
where :math:`\hat{\mu}(x)=F_{\mu}^{-1}(x-\int z\mathrm{d}\mu(z)+\frac12) - x` for all :math:`x\in [0,1[`,
1177+
and :math:`d_{S^1}(x,y)=\min(|x-y|, 1-|x-y|)` for :math:`x,y\in [0,1[`.
11461178
11471179
Parameters
11481180
----------
1149-
1150-
Inputs:
1151-
- u_values: shape (n, ...) - samples in the source domain (coordinates on [0,1[)
1152-
- v_values: shape (m, ...) , optional- samples in the target domain (coordinates on [0,1[), if None, compute distance against uniform distribution
1153-
- u_weights: shape (n, ...), optional - weights of the first empirical distribution, if None then uniform weights are used
1154-
- v_weights, shape (m, ...), optional - weights of the second empirical distribution, if None then uniform weights are used
1181+
u_values : ndarray, shape (n, ...)
1182+
samples in the source domain (coordinates on [0,1[)
1183+
v_values : ndarray, shape (n, ...), optional
1184+
samples in the target domain (coordinates on [0,1[), if None, compute distance against uniform distribution
1185+
u_weights : ndarray, shape (n, ...), optional
1186+
samples weights in the source domain
1187+
v_weights : ndarray, shape (n, ...), optional
1188+
samples weights in the target domain
11551189
11561190
Returns
11571191
-------
1158-
Outputs:
1159-
- return batchs LCOT
1192+
loss: float
1193+
Cost associated to the linear optimal transportation
11601194
11611195
Examples
11621196
--------
1197+
>>> u = np.array([[0.2,0.5,0.8]])%1
1198+
>>> v = np.array([[0.4,0.5,0.7]])%1
1199+
>>> linear_circular_ot(u.T, v.T)
1200+
array([0.0127])
1201+
11631202
1203+
.. _references-lcot:
11641204
References
11651205
----------
11661206
.. [76] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations.
@@ -1187,8 +1227,8 @@ def linear_circular_ot(u_values, v_values=None, u_weights=None, v_weights=None,
11871227

11881228
if v_values is None:
11891229
dist_u = nx.minimum(nx.abs(emb_u), 1 - nx.abs(emb_u))
1190-
return nx.mean(dist_u**p, axis=0)
1230+
return nx.mean(dist_u**2, axis=0)
11911231

11921232
emb_v = linear_circular_embedding(unif_s1, v_values, v_weights)
11931233
dist_uv = nx.minimum(nx.abs(emb_u - emb_v), 1 - nx.abs(emb_u - emb_v))
1194-
return nx.mean(dist_uv**p, axis=0)
1234+
return nx.mean(dist_uv**2, axis=0)

0 commit comments

Comments
 (0)