|
64 | 64 | n = 136 # number of points of the barycentre |
65 | 65 | d = 2 # dimensions of the original measure |
66 | 66 | K = 4 # number of measures to barycentre |
67 | | -m = 50 # number of points of the measures |
68 | | -b_list = [torch.ones(m) / m] * K # weights of the 4 measures |
| 67 | +m_list = [49, 50, 51, 51] # number of points of the measures |
| 68 | +b_list = [torch.ones(m) / m for m in m_list] # weights of the 4 measures |
69 | 69 | weights = torch.ones(K) / K # weights for the barycentre |
70 | 70 | stop_threshold = 1e-20 # stop threshold for B and for fixed-point algo |
71 | 71 |
|
@@ -94,7 +94,7 @@ def proj_circle(X, origin, radius): |
94 | 94 | # onto the K circles |
95 | 95 | Y_list = [] |
96 | 96 | for k in range(K): |
97 | | - t = torch.rand(m) * 2 * np.pi |
| 97 | + t = torch.rand(m_list[k]) * 2 * np.pi |
98 | 98 | X_temp = 0.5 * torch.stack([torch.cos(t), torch.sin(t)], axis=1) |
99 | 99 | X_temp = X_temp + torch.tensor([0.5, 0.5])[None, :] |
100 | 100 | Y_list.append(P_list[k](X_temp)) |
@@ -244,8 +244,10 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold): |
244 | 244 | plt.tight_layout() |
245 | 245 |
|
246 | 246 | # %% |
247 | | -# Plot energy convergence |
248 | | -fig, axes = plt.subplots(1, 3, figsize=(12, 4)) |
| 247 | +# Plot energy convergence and support sizes |
| 248 | +size = 3 |
| 249 | +n_plots = 4 |
| 250 | +fig, axes = plt.subplots(1, n_plots, figsize=(size * n_plots, size)) |
249 | 251 | V_list = [V.item() for V in log_dict["V_list"]] |
250 | 252 | V_list2 = [V.item() for V in log_dict2["V_list"]] |
251 | 253 | diff = np.array(V_list2) - np.array(V_list) |
@@ -277,6 +279,23 @@ def B(y, its=150, lr=1, stop_threshold=stop_threshold): |
277 | 279 | axes[2].set_yscale("log") |
278 | 280 | axes[2].xaxis.set_major_locator(plt.MaxNLocator(integer=True)) |
279 | 281 |
|
| 282 | +# plot support sizes |
| 283 | +support_sizes = [Xi.shape[0] for Xi in log_dict["X_list"]] |
| 284 | +support_sizes2 = [Xi.shape[0] for Xi in log_dict2["X_list"]] |
| 285 | + |
| 286 | +axes[3].plot(support_sizes, color="C0", lw=5, alpha=0.6, label="True FP") |
| 287 | +axes[3].scatter( |
| 288 | + range(len(support_sizes)), support_sizes, color="blue", alpha=0.8, s=100 |
| 289 | +) |
| 290 | +axes[3].plot(support_sizes2, color="red", lw=5, alpha=0.6, label="Heur. FP") |
| 291 | +axes[3].scatter( |
| 292 | + range(len(support_sizes2)), support_sizes2, color="red", alpha=0.8, s=100 |
| 293 | +) |
| 294 | +axes[3].legend(loc="best") |
| 295 | +axes[3].set_xlabel("Iteration") |
| 296 | +axes[3].xaxis.set_major_locator(plt.MaxNLocator(integer=True)) |
| 297 | +axes[3].set_title("Support Sizes") |
| 298 | + |
280 | 299 | plt.tight_layout() |
281 | 300 | plt.show() |
282 | 301 |
|
|
0 commit comments