|
12 | 12 | import matplotlib.pyplot as plt |
13 | 13 | from ot.partial import partial_wasserstein_1d |
14 | 14 |
|
| 15 | + |
| 16 | +def plot_partial_transport( |
| 17 | + ax, x_a, x_b, indices_a=None, indices_b=None, marginal_costs=None |
| 18 | +): |
| 19 | + y_a = np.ones_like(x_a) |
| 20 | + y_b = -np.ones_like(x_b) |
| 21 | + |
| 22 | + # Plot all points |
| 23 | + ax.plot(x_a, y_a, "o", color="C0", label="x_a") |
| 24 | + ax.plot(x_b, y_b, "o", color="C1", label="x_b") |
| 25 | + |
| 26 | + # Plot transport lines |
| 27 | + if indices_a is not None and indices_b is not None: |
| 28 | + subset_a = np.sort(x_a[indices_a]) |
| 29 | + subset_b = np.sort(x_b[indices_b]) |
| 30 | + |
| 31 | + for x_a_i, x_b_j in zip(subset_a, subset_b): |
| 32 | + ax.plot([x_a_i, x_b_j], [1, -1], "k--", alpha=0.7) |
| 33 | + |
| 34 | + if marginal_costs is not None: |
| 35 | + k = len(marginal_costs) |
| 36 | + ax.set_title( |
| 37 | + f"Partial Transport - k = {k}, Cumulative Cost = {marginal_costs.sum():.2f}" |
| 38 | + ) |
| 39 | + else: |
| 40 | + ax.set_title("Original 1D Discrete Distributions") |
| 41 | + ax.legend(loc="upper right") |
| 42 | + ax.set_yticks([]) |
| 43 | + ax.set_xticks([]) |
| 44 | + ax.set_ylim(-2, 2) |
| 45 | + ax.set_xlim(min(x_a.min(), x_b.min()) - 1, max(x_a.max(), x_b.max()) + 1) |
| 46 | + ax.axis("off") |
| 47 | + |
| 48 | + |
15 | 49 | # Simulate two 1D discrete distributions |
16 | | -np.random.seed(42) |
| 50 | +np.random.seed(0) |
17 | 51 | n = 6 |
18 | 52 | x_a = np.sort(np.random.uniform(0, 10, size=n)) |
19 | 53 | x_b = np.sort(np.random.uniform(0, 10, size=n)) |
20 | 54 |
|
21 | 55 | # Plot original distributions |
22 | 56 | plt.figure(figsize=(10, 2)) |
23 | | -plt.eventplot([x_a, x_b], lineoffsets=[1, -1], colors=["C0", "C1"], linelengths=0.6) |
24 | | -plt.yticks([1, -1], ["x_a", "x_b"]) |
25 | | -plt.title("Original 1D Discrete Distributions") |
26 | | -plt.grid(True) |
| 57 | +plot_partial_transport(plt.gca(), x_a, x_b) |
27 | 58 | plt.show() |
28 | 59 |
|
29 | 60 | # %% |
|
36 | 67 | fig, axes = plt.subplots(n, 1, figsize=(10, 2.2 * n), sharex=True) |
37 | 68 |
|
38 | 69 | for k, ax in enumerate(axes): |
39 | | - ax.eventplot([x_a, x_b], lineoffsets=[1, -1], colors=["C0", "C1"], linelengths=0.6) |
40 | | - ax.set_yticks([1, -1]) |
41 | | - ax.set_yticklabels(["x_a", "x_b"]) |
42 | | - ax.set_title( |
43 | | - f"Partial Transport - k = {k+1}, Cumulative Cost = {cumulative_costs[k]:.2f}" |
| 70 | + plot_partial_transport( |
| 71 | + ax, x_a, x_b, indices_a[: k + 1], indices_b[: k + 1], marginal_costs[: k + 1] |
44 | 72 | ) |
45 | | - ax.grid(True) |
46 | | - |
47 | | - subset_a = np.sort(x_a[indices_a[: k + 1]]) |
48 | | - subset_b = np.sort(x_b[indices_b[: k + 1]]) |
49 | | - |
50 | | - for x_a_i, x_b_j in zip(subset_a, subset_b): |
51 | | - ax.plot([x_a_i, x_b_j], [1, -1], "k--", alpha=0.7) |
52 | 73 |
|
53 | 74 | plt.tight_layout() |
54 | 75 | plt.show() |
0 commit comments