Skip to content

Commit fc3454f

Browse files
committed
improve coverage and fix doc
1 parent 48bd998 commit fc3454f

File tree

4 files changed

+22
-1
lines changed

4 files changed

+22
-1
lines changed

ot/backend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2241,6 +2241,8 @@ def randn(self, *size, type_as=None):
22412241
return torch.randn(size=size, generator=self.rng_)
22422242

22432243
def randperm(self, size, type_as=None):
2244+
if not isinstance(size, int):
2245+
raise ValueError("size must be an integer")
22442246
if type_as is not None:
22452247
generator = (
22462248
self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_

ot/bregman/_empirical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -936,7 +936,7 @@ def empirical_sinkhorn_nystroem2(
936936
>>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
937937
>>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
938938
>>> empirical_sinkhorn_nystroem2(X_s, X_t, reg, anchors, random_state=42) # doctest: +ELLIPSIS
939-
1.9138489870270898
939+
2.5
940940
941941
942942
References

test/test_backend.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ def test_empty_backend():
200200
nx.seed(42)
201201
with pytest.raises(NotImplementedError):
202202
nx.rand()
203+
with pytest.raises(NotImplementedError):
204+
nx.randperm(12)
203205
with pytest.raises(NotImplementedError):
204206
nx.randn()
205207
nx.coo_matrix(M, M, M)
@@ -215,6 +217,8 @@ def test_empty_backend():
215217
nx.where(M, M, M)
216218
with pytest.raises(NotImplementedError):
217219
nx.copy(M)
220+
with pytest.raises(NotImplementedError):
221+
nx.pinv(M)
218222
with pytest.raises(NotImplementedError):
219223
nx.allclose(M, M)
220224
with pytest.raises(NotImplementedError):
@@ -759,6 +763,9 @@ def test_random_backends(nx):
759763
M4 = nx.sort(nx.randperm(5))
760764
assert np.allclose(M3, M4)
761765

766+
with pytest.raises(ValueError, match="size must be"):
767+
res = nx.randperm(size=[5, 12])
768+
762769

763770
def test_gradients_backends():
764771
rnd = np.random.RandomState(0)

test/test_lowrank.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,18 @@ def test_nystroem_sinkhorn2(log):
144144

145145
np.testing.assert_allclose(loss1, loss2, atol=1e-07, rtol=1e-3)
146146

147+
with pytest.raises(ValueError, match="anchors must"):
148+
res = ot.bregman.empirical_sinkhorn_nystroem2(
149+
Xs,
150+
Xt,
151+
anchors=1,
152+
reg=reg,
153+
numItermax=3000,
154+
verbose=True,
155+
random_state=random_state,
156+
log=log,
157+
)
158+
147159

148160
def test_compute_lr_sqeuclidean_matrix():
149161
# test computation of low rank cost matrices M1 and M2

0 commit comments

Comments
 (0)