Skip to content

Commit cd53295

Browse files
committed
git import in doctests of sliced
1 parent 150824d commit cd53295

File tree

4 files changed

+17
-10
lines changed

4 files changed

+17
-10
lines changed

examples/sliced-wasserstein/plot_sliced_plans.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
===============
66
77
Compares different Sliced OT plans between two 2D point clouds. The min-Pivot
8-
Sliced plan was introduced in [82], and the Expected Sliced plan in [84], both
9-
were further studied theoretically in [83].
8+
Sliced plan was introduced in [83], and the Expected Sliced plan in [85], both
9+
were further studied theoretically in [84].
1010
1111
.. [83] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385.
1212

ot/sliced/_sliced_distances.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,10 @@ def sliced_wasserstein_distance(
6868
Examples
6969
--------
7070
71+
>>> import ot
7172
>>> n_samples_a = 20
7273
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
73-
>>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
74+
>>> ot.sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
7475
0.0
7576
7677
References
@@ -174,9 +175,10 @@ def max_sliced_wasserstein_distance(
174175
Examples
175176
--------
176177
178+
>>> import ot
177179
>>> n_samples_a = 20
178180
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
179-
>>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
181+
>>> ot.sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
180182
0.0
181183
182184
References

ot/sliced/_sliced_plans.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,10 +249,11 @@ def min_pivot_sliced(
249249
250250
Examples
251251
--------
252+
>>> import ot
252253
>>> x=np.array([[3.,3.], [1.,1.]])
253254
>>> y=np.array([[2.,2.5], [3.,2.]])
254255
>>> thetas=np.array([[1, 0], [0, 1]])
255-
>>> plan, cost = min_pivot_sliced(x, y, thetas=thetas)
256+
>>> plan, cost = ot.min_pivot_sliced(x, y, thetas=thetas)
256257
>>> plan
257258
array([[0. , 0.5],
258259
[0.5, 0. ]])
@@ -409,10 +410,11 @@ def expected_sliced(
409410
410411
Examples
411412
--------
413+
>>> import ot
412414
>>> x=np.array([[3.,3.], [1.,1.]])
413415
>>> y=np.array([[2.,2.5], [3.,2.]])
414416
>>> thetas=np.array([[1, 0], [0, 1]])
415-
>>> plan, cost = expected_sliced(x, y, thetas=thetas)
417+
>>> plan, cost = ot.expected_sliced(x, y, thetas=thetas)
416418
>>> plan
417419
array([[0.25, 0.25],
418420
[0.25, 0.25]])

ot/sliced/_spherical_sliced.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,11 @@ def sliced_wasserstein_sphere(
7070
7171
Examples
7272
--------
73+
>>> import ot
7374
>>> n_samples_a = 20
7475
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
7576
>>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True))
76-
>>> sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
77+
>>> ot.sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
7778
0.0
7879
7980
References
@@ -155,11 +156,12 @@ def sliced_wasserstein_sphere_unif(
155156
156157
Examples
157158
---------
159+
>>> import ot
158160
>>> np.random.seed(42)
159161
>>> x0 = np.random.randn(500,3)
160162
>>> x0 = x0 / np.sqrt(np.sum(x0**2, -1, keepdims=True))
161-
>>> ssw = sliced_wasserstein_sphere_unif(x0, seed=42)
162-
>>> np.allclose(sliced_wasserstein_sphere_unif(x0, seed=42), 0.01734, atol=1e-3)
163+
>>> ssw = ot.sliced_wasserstein_sphere_unif(x0, seed=42)
164+
>>> np.allclose(ot.sliced_wasserstein_sphere_unif(x0, seed=42), 0.01734, atol=1e-3)
163165
True
164166
165167
References:
@@ -240,10 +242,11 @@ def linear_sliced_wasserstein_sphere(
240242
241243
Examples
242244
---------
245+
>>> import ot
243246
>>> n_samples_a = 20
244247
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
245248
>>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True))
246-
>>> linear_sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
249+
>>> ot.linear_sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
247250
0.0
248251
249252

0 commit comments

Comments
 (0)