Skip to content

Commit 38e65c9

Browse files
committed
applying PR comments (still some doc rendering issues to fix)
1 parent 6554189 commit 38e65c9

File tree

7 files changed

+177
-83
lines changed

7 files changed

+177
-83
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ POT provides the following generic OT solvers (links to examples):
5454
* [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and
5555
[unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71].
5656
* Fused unbalanced Gromov-Wasserstein [70].
57-
* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [76]
58-
* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 76]
57+
* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [77]
58+
* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 77]
5959

6060
POT provides the following Machine Learning related solvers:
6161

examples/barycenters/plot_free_support_barycenter_generic_cost.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,28 @@
88
a ground cost that is not a power of a norm. We take the example of ground costs
99
:math:`c_k(x, y) = \lambda_k\|P_k(x)-y\|_2^2`, where :math:`P_k` is the
1010
(non-linear) projection onto a circle k, and :math:`(\lambda_k)` are weights. A
11-
barycenter is defined ([76]) as a minimiser of the energy :math:`V(\mu) = \sum_k
11+
barycenter is defined ([77]) as a minimiser of the energy :math:`V(\mu) = \sum_k
1212
\mathcal{T}_{c_k}(\mu, \nu_k)` where :math:`\mu` is a candidate barycenter
1313
measure, the measures :math:`\nu_k` are the target measures and
1414
:math:`\mathcal{T}_{c_k}` is the OT cost for ground cost :math:`c_k`. This is an
15-
example of the fixed-point barycenter solver introduced in [76] which
15+
example of the fixed-point barycenter solver introduced in [77] which
1616
generalises [20] and [43].
1717
1818
The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in
1919
\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k)` is computed by gradient descent over
2020
:math:`x` with Pytorch.
2121
22-
We compare two algorithms from [76]: the first ([76], Algorithm 2,
22+
We compare two algorithms from [77]: the first ([77], Algorithm 2,
2323
'true_fixed_point' in POT) has convergence guarantees but the iterations may
2424
increase in support size and thus require more computational resources. The
25-
second ([76], Algorithm 3, 'L2_barycentric_proj' in POT) is a simplified
25+
second ([77], Algorithm 3, 'L2_barycentric_proj' in POT) is a simplified
2626
heuristic that imposes a fixed support size for the barycenter and fixed
2727
weights.
2828
2929
We initialise both algorithms with a support size of 136, computing a barycenter
3030
between measures with uniform weights and 50 points.
3131
32-
[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
32+
[77] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
3333
Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016
3434
(2024)
3535
@@ -51,7 +51,6 @@
5151
# %%
5252
# Generate data
5353
import torch
54-
import ot
5554
from torch.optim import Adam
5655
from ot.utils import dist
5756
import numpy as np
@@ -62,7 +61,7 @@
6261

6362
torch.manual_seed(42)
6463

65-
n = 136 # number of points of the of the barycentre
64+
n = 136 # number of points of the barycentre
6665
d = 2 # dimensions of the original measure
6766
K = 4 # number of measures to barycentre
6867
m = 50 # number of points of the measures
@@ -204,15 +203,6 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold):
204203
s = 80
205204
labels = ["circle 1", "circle 2", "circle 3", "circle 4"]
206205

207-
208-
# Compute barycenter energies
209-
def V(X, a):
210-
v = 0
211-
for k in range(K):
212-
v += (1 / K) * ot.emd2(a, b_list[k], cost_list[k](X, Y_list[k]))
213-
return v
214-
215-
216206
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
217207

218208
# Plot for the true fixed-point algorithm
@@ -228,7 +218,7 @@ def V(X, a):
228218
axes[0].set_title(
229219
"True Fixed-Point Algorithm\n"
230220
f"Support size: {a_bar.shape[0]}\n"
231-
f"Barycenter cost: {V(X_bar, a_bar).item():.6f}\n"
221+
f"Barycenter cost: {log_dict['V_list'][-1].item():.6f}\n"
232222
f"Computation time {dt_true_fixed_point:.4f}s"
233223
)
234224
axes[0].axis("equal")
@@ -244,7 +234,7 @@ def V(X, a):
244234
axes[1].set_title(
245235
"Heuristic Barycentric Algorithm\n"
246236
f"Support size: {X_bar2.shape[0]}\n"
247-
f"Barycenter cost: {V(X_bar2, torch.ones(n) / n).item():.6f}\n"
237+
f"Barycenter cost: {log_dict2['V_list'][-1].item():.6f}\n"
248238
f"Computation time {dt_barycentric:.4f}s"
249239
)
250240
axes[1].axis("equal")
@@ -255,10 +245,10 @@ def V(X, a):
255245

256246
# %%
257247
# Plot energy convergence
258-
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
259-
260-
V_list = [V(X, a).item() for (X, a) in zip(log_dict["X_list"], log_dict["a_list"])]
261-
V_list2 = [V(X, torch.ones(n) / n).item() for X in log_dict2["X_list"]]
248+
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
249+
V_list = [V.item() for V in log_dict["V_list"]]
250+
V_list2 = [V.item() for V in log_dict2["V_list"]]
251+
diff = np.array(V_list2) - np.array(V_list)
262252

263253
# Plot for True Fixed-Point Algorithm
264254
axes[0].plot(V_list, lw=5, alpha=0.6)
@@ -278,6 +268,15 @@ def V(X, a):
278268
axes[1].set_yscale("log")
279269
axes[1].xaxis.set_major_locator(plt.MaxNLocator(integer=True))
280270

271+
# Plot difference between the two
272+
axes[2].plot(diff, lw=5, alpha=0.6)
273+
axes[2].scatter(range(len(diff)), diff, color="blue", alpha=0.8, s=100)
274+
axes[2].set_title("Heuristic Fixed-Point Energy - True")
275+
axes[2].set_xlabel("Iteration")
276+
axes[2].set_ylabel("$V_{\\mathrm{heuristic}} - V_{\\mathrm{true}}$")
277+
axes[2].set_yscale("log")
278+
axes[2].xaxis.set_major_locator(plt.MaxNLocator(integer=True))
279+
281280
plt.tight_layout()
282281
plt.show()
283282

examples/barycenters/plot_gmm_barycenter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
77
This example illustrates the computation of a barycenter between Gaussian
88
Mixtures in the sense of GMM-OT [69]. This computation is done using the
9-
fixed-point method for OT barycenters with generic costs [76], for which POT
9+
fixed-point method for OT barycenters with generic costs [77], for which POT
1010
provides a general solver, and a specific GMM solver. Note that this is a
1111
'free-support' method, implying that the number of components of the barycenter
1212
GMM and their weights are fixed.
@@ -22,7 +22,7 @@
2222
[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space
2323
of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
2424
25-
[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
25+
[77] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
2626
Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016
2727
(2024)
2828

ot/gmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def gmm_barycenter_fixed_point(
456456
):
457457
r"""
458458
Solves the Gaussian Mixture Model OT barycenter problem (defined in [69])
459-
using the fixed point algorithm (proposed in [76]). The
459+
using the fixed point algorithm (proposed in [77]). The
460460
weights of the barycenter are not optimized, and stay the same as the input
461461
`w_list` or are initialized to uniform.
462462
@@ -504,7 +504,7 @@ def gmm_barycenter_fixed_point(
504504
----------
505505
.. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
506506
507-
.. [76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024)
507+
.. [77] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing barycenters of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024)
508508
509509
See Also
510510
--------

ot/lp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
free_support_barycenter,
1616
generalized_free_support_barycenter,
1717
free_support_barycenter_generic_costs,
18+
ot_barycenter_energy,
1819
NorthWestMMGluing,
1920
)
2021
from ..utils import check_number_threads
@@ -49,4 +50,5 @@
4950
"check_number_threads",
5051
"free_support_barycenter_generic_costs",
5152
"NorthWestMMGluing",
53+
"ot_barycenter_energy",
5254
]

0 commit comments

Comments
 (0)