Skip to content

Commit 934af20

Browse files
committed
changed stop threshold method in groundBary and changed the example to have different support sizes
1 parent 5f58e38 commit 934af20

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

examples/barycenters/plot_free_support_barycenter_generic_cost.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@
6464
n = 136 # number of points of the barycentre
6565
d = 2 # dimensions of the original measure
6666
K = 4 # number of measures to barycentre
67-
m = 50 # number of points of the measures
68-
b_list = [torch.ones(m) / m] * K # weights of the 4 measures
67+
m_list = [49, 50, 51, 51] # number of points of the measures
68+
b_list = [torch.ones(m) / m for m in m_list] # weights of the 4 measures
6969
weights = torch.ones(K) / K # weights for the barycentre
7070
stop_threshold = 1e-20 # stop threshold for B and for fixed-point algo
7171

@@ -94,7 +94,7 @@ def proj_circle(X, origin, radius):
9494
# onto the K circles
9595
Y_list = []
9696
for k in range(K):
97-
t = torch.rand(m) * 2 * np.pi
97+
t = torch.rand(m_list[k]) * 2 * np.pi
9898
X_temp = 0.5 * torch.stack([torch.cos(t), torch.sin(t)], axis=1)
9999
X_temp = X_temp + torch.tensor([0.5, 0.5])[None, :]
100100
Y_list.append(P_list[k](X_temp))
@@ -244,8 +244,10 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold):
244244
plt.tight_layout()
245245

246246
# %%
247-
# Plot energy convergence
248-
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
247+
# Plot energy convergence and support sizes
248+
size = 3
249+
n_plots = 4
250+
fig, axes = plt.subplots(1, n_plots, figsize=(size * n_plots, size))
249251
V_list = [V.item() for V in log_dict["V_list"]]
250252
V_list2 = [V.item() for V in log_dict2["V_list"]]
251253
diff = np.array(V_list2) - np.array(V_list)
@@ -277,6 +279,23 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold):
277279
axes[2].set_yscale("log")
278280
axes[2].xaxis.set_major_locator(plt.MaxNLocator(integer=True))
279281

282+
# plot support sizes
283+
support_sizes = [Xi.shape[0] for Xi in log_dict["X_list"]]
284+
support_sizes2 = [Xi.shape[0] for Xi in log_dict2["X_list"]]
285+
286+
axes[3].plot(support_sizes, color="C0", lw=5, alpha=0.6, label="True FP")
287+
axes[3].scatter(
288+
range(len(support_sizes)), support_sizes, color="blue", alpha=0.8, s=100
289+
)
290+
axes[3].plot(support_sizes2, color="red", lw=5, alpha=0.6, label="Heur. FP")
291+
axes[3].scatter(
292+
range(len(support_sizes2)), support_sizes2, color="red", alpha=0.8, s=100
293+
)
294+
axes[3].legend(loc="best")
295+
axes[3].set_xlabel("Iteration")
296+
axes[3].xaxis.set_major_locator(plt.MaxNLocator(integer=True))
297+
axes[3].set_title("Support Sizes")
298+
280299
plt.tight_layout()
281300
plt.show()
282301

ot/lp/_barycenter_solvers.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -681,8 +681,8 @@ def ground_bary(y, x_init):
681681
x = x_init.clone().detach().requires_grad_(True)
682682
solver = Adam if ground_bary_solver == "Adam" else SGD
683683
opt = solver([x], lr=ground_bary_lr)
684-
for _ in range(ground_bary_numItermax):
685-
x_prev = x.data.clone()
684+
loss_prev = None
685+
for i in range(ground_bary_numItermax):
686686
opt.zero_grad()
687687
# inefficient cost computation but compatible
688688
# with the choice of cost_list[k] giving the cost matrix
@@ -693,7 +693,11 @@ def ground_bary(y, x_init):
693693
)
694694
loss.backward()
695695
opt.step()
696-
diff = torch.sum((x.data - x_prev) ** 2)
696+
if i == 0:
697+
diff = ground_bary_stopThr + 1.0
698+
else:
699+
diff = torch.sum((loss.item() - loss_prev) ** 2)
700+
loss_prev = loss.item()
697701
if diff < ground_bary_stopThr:
698702
break
699703
return x.detach()

0 commit comments

Comments
 (0)