Skip to content

Commit 3e362fe

Browse files
committed
correct import in doctests of sliced
1 parent cd53295 commit 3e362fe

File tree

3 files changed

+8
-1
lines changed

3 files changed

+8
-1
lines changed

ot/sliced/_sliced_distances.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def sliced_wasserstein_distance(
6969
--------
7070
7171
>>> import ot
72+
>>> import numpy as np
7273
>>> n_samples_a = 20
7374
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
7475
>>> ot.sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
@@ -176,6 +177,7 @@ def max_sliced_wasserstein_distance(
176177
--------
177178
178179
>>> import ot
180+
>>> import numpy as np
179181
>>> n_samples_a = 20
180182
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
181183
>>> ot.sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE

ot/sliced/_sliced_plans.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ def min_pivot_sliced(
250250
Examples
251251
--------
252252
>>> import ot
253+
>>> import numpy as np
253254
>>> x=np.array([[3.,3.], [1.,1.]])
254255
>>> y=np.array([[2.,2.5], [3.,2.]])
255256
>>> thetas=np.array([[1, 0], [0, 1]])
@@ -411,6 +412,7 @@ def expected_sliced(
411412
Examples
412413
--------
413414
>>> import ot
415+
>>> import numpy as np
414416
>>> x=np.array([[3.,3.], [1.,1.]])
415417
>>> y=np.array([[2.,2.5], [3.,2.]])
416418
>>> thetas=np.array([[1, 0], [0, 1]])
@@ -448,7 +450,7 @@ def expected_sliced(
448450

449451
log_dict = {}
450452
G, costs, log_dict_plans = sliced_plans(
451-
X_s, X_t, a, b, metric, p, thetas, n_proj=n_proj, log=True
453+
X_s, X_t, a, b, metric, p, thetas, n_proj=n_proj, log=log
452454
)
453455

454456
if beta != 0.0: # computing the temperature weighting

ot/sliced/_spherical_sliced.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def sliced_wasserstein_sphere(
7171
Examples
7272
--------
7373
>>> import ot
74+
>>> import numpy as np
7475
>>> n_samples_a = 20
7576
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
7677
>>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True))
@@ -157,6 +158,7 @@ def sliced_wasserstein_sphere_unif(
157158
Examples
158159
---------
159160
>>> import ot
161+
>>> import numpy as np
160162
>>> np.random.seed(42)
161163
>>> x0 = np.random.randn(500,3)
162164
>>> x0 = x0 / np.sqrt(np.sum(x0**2, -1, keepdims=True))
@@ -243,6 +245,7 @@ def linear_sliced_wasserstein_sphere(
243245
Examples
244246
---------
245247
>>> import ot
248+
>>> import numpy as np
246249
>>> n_samples_a = 20
247250
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
248251
>>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True))

0 commit comments

Comments
 (0)