Skip to content

Commit ff42483

Browse files
committed
minor refactor
1 parent 843324c commit ff42483

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

examples/others/plot_partial_1d.py renamed to examples/unbalanced-partial/plot_partial_1d.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""
2-
Partial Wasserstein 1D - Gallery Example
2+
=========================
3+
Partial Wasserstein in 1D
4+
=========================
35
46
This script demonstrates how to compute and visualize the Partial Wasserstein distance between two 1D discrete distributions using `ot.partial.partial_wasserstein_1d`.
57
@@ -25,7 +27,6 @@
2527
plt.show()
2628

2729
# %%
28-
# Run the function on our example data
2930
indices_a, indices_b, marginal_costs = partial_wasserstein_1d(x_a, x_b)
3031

3132
# Compute cumulative cost
@@ -43,8 +44,11 @@
4344
)
4445
ax.grid(True)
4546

46-
for i, j in zip(indices_a[: k + 1], indices_b[: k + 1]):
47-
ax.plot([x_a[i], x_b[j]], [1, -1], "k--", alpha=0.7)
47+
subset_a = np.sort(x_a[indices_a[: k + 1]])
48+
subset_b = np.sort(x_b[indices_b[: k + 1]])
49+
50+
for x_a_i, x_b_j in zip(subset_a, subset_b):
51+
ax.plot([x_a_i, x_b_j], [1, -1], "k--", alpha=0.7)
4852

4953
plt.tight_layout()
5054
plt.show()

0 commit comments

Comments
 (0)