Skip to content

Commit bbc930b

Browse files
committed
define a function for plotting
1 parent ec7d9c1 commit bbc930b

File tree

1 file changed

+38
-17
lines changed

1 file changed

+38
-17
lines changed

examples/unbalanced-partial/plot_partial_1d.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,49 @@
1212
import matplotlib.pyplot as plt
1313
from ot.partial import partial_wasserstein_1d
1414

15+
16+
def plot_partial_transport(
17+
ax, x_a, x_b, indices_a=None, indices_b=None, marginal_costs=None
18+
):
19+
y_a = np.ones_like(x_a)
20+
y_b = -np.ones_like(x_b)
21+
22+
# Plot all points
23+
ax.plot(x_a, y_a, "o", color="C0", label="x_a")
24+
ax.plot(x_b, y_b, "o", color="C1", label="x_b")
25+
26+
# Plot transport lines
27+
if indices_a is not None and indices_b is not None:
28+
subset_a = np.sort(x_a[indices_a])
29+
subset_b = np.sort(x_b[indices_b])
30+
31+
for x_a_i, x_b_j in zip(subset_a, subset_b):
32+
ax.plot([x_a_i, x_b_j], [1, -1], "k--", alpha=0.7)
33+
34+
if marginal_costs is not None:
35+
k = len(marginal_costs)
36+
ax.set_title(
37+
f"Partial Transport - k = {k}, Cumulative Cost = {marginal_costs.sum():.2f}"
38+
)
39+
else:
40+
ax.set_title("Original 1D Discrete Distributions")
41+
ax.legend(loc="upper right")
42+
ax.set_yticks([])
43+
ax.set_xticks([])
44+
ax.set_ylim(-2, 2)
45+
ax.set_xlim(min(x_a.min(), x_b.min()) - 1, max(x_a.max(), x_b.max()) + 1)
46+
ax.axis("off")
47+
48+
1549
# Simulate two 1D discrete distributions
16-
np.random.seed(42)
50+
np.random.seed(0)
1751
n = 6
1852
x_a = np.sort(np.random.uniform(0, 10, size=n))
1953
x_b = np.sort(np.random.uniform(0, 10, size=n))
2054

2155
# Plot original distributions
2256
plt.figure(figsize=(10, 2))
23-
plt.eventplot([x_a, x_b], lineoffsets=[1, -1], colors=["C0", "C1"], linelengths=0.6)
24-
plt.yticks([1, -1], ["x_a", "x_b"])
25-
plt.title("Original 1D Discrete Distributions")
26-
plt.grid(True)
57+
plot_partial_transport(plt.gca(), x_a, x_b)
2758
plt.show()
2859

2960
# %%
@@ -36,19 +67,9 @@
3667
fig, axes = plt.subplots(n, 1, figsize=(10, 2.2 * n), sharex=True)
3768

3869
for k, ax in enumerate(axes):
39-
ax.eventplot([x_a, x_b], lineoffsets=[1, -1], colors=["C0", "C1"], linelengths=0.6)
40-
ax.set_yticks([1, -1])
41-
ax.set_yticklabels(["x_a", "x_b"])
42-
ax.set_title(
43-
f"Partial Transport - k = {k+1}, Cumulative Cost = {cumulative_costs[k]:.2f}"
70+
plot_partial_transport(
71+
ax, x_a, x_b, indices_a[: k + 1], indices_b[: k + 1], marginal_costs[: k + 1]
4472
)
45-
ax.grid(True)
46-
47-
subset_a = np.sort(x_a[indices_a[: k + 1]])
48-
subset_b = np.sort(x_b[indices_b[: k + 1]])
49-
50-
for x_a_i, x_b_j in zip(subset_a, subset_b):
51-
ax.plot([x_a_i, x_b_j], [1, -1], "k--", alpha=0.7)
5273

5374
plt.tight_layout()
5475
plt.show()

0 commit comments

Comments
 (0)