Skip to content

Commit 0534326

Browse files
committed
remove multiple targets from low rank sinkhorn + little more coverage
1 parent 875ed4a commit 0534326

File tree

2 files changed

+10
-18
lines changed

2 files changed

+10
-18
lines changed

ot/lowrank.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,7 @@ def sinkhorn_low_rank_kernel(
600600
):
601601
r"""
602602
Compute the Sinkhorn algorithm for a kernel :math:`\mathbf{K}` that can be written as a low rank factorization :math:`\mathbf{K} = \mathbf{K}_1 \mathbf{K}_2^\top`.
603+
Does not implement multiple targets.
603604
604605
Precisely :
605606
@@ -620,9 +621,8 @@ def sinkhorn_low_rank_kernel(
620621
Right factor
621622
a : array-like, shape (n_samples_a,)
622623
samples weights in the source domain
623-
b : array-like, shape (n_samples_b,) or array-like, shape (n_samples_b, n_hists)
624-
samples in the target domain, compute sinkhorn with multiple targets
625-
if :math:`\mathbf{b}` is a matrix
624+
b : array-like, shape (n_samples_b,)
625+
samples in the target domain
626626
numItermax : int, optional
627627
Max number of iterations
628628
stopThr : float, optional
@@ -639,9 +639,9 @@ def sinkhorn_low_rank_kernel(
639639
640640
Returns
641641
---------
642-
u : array-like, shape (n_samples_a, ) or array-like, shape (n_samples_a, n_hists)
642+
u : array-like, shape (n_samples_a, )
643643
Left dual variable
644-
v: array-like, shape (n_samples_b, ) or array-like, shape (n_samples_b, n_hists)
644+
v: array-like, shape (n_samples_b, )
645645
Right dual variable
646646
log : dict (lazy_plan)
647647
log dictionary return only if log==True in parameters
@@ -659,23 +659,14 @@ def sinkhorn_low_rank_kernel(
659659
dim_a = len(a)
660660
dim_b = b.shape[0]
661661

662-
if len(b.shape) > 1:
663-
n_hists = b.shape[1]
664-
else:
665-
n_hists = 0
666-
667662
if log:
668663
dict_log = {"err": []}
669664

670665
# we assume that no distances are null except those of the diagonal of
671666
# distances
672667
if warmstart is None:
673-
if n_hists:
674-
u = nx.ones((dim_a, n_hists), type_as=K1) / dim_a
675-
v = nx.ones((dim_b, n_hists), type_as=K2) / dim_b
676-
else:
677-
u = nx.ones(dim_a, type_as=K1) / dim_a
678-
v = nx.ones(dim_b, type_as=K2) / dim_b
668+
u = nx.ones(dim_a, type_as=K1) / dim_a
669+
v = nx.ones(dim_b, type_as=K2) / dim_b
679670
else:
680671
u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
681672

test/test_lowrank.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def test_nystroem_kernel_approx():
2424
np.testing.assert_allclose(K, U @ V.T, atol=1e-7)
2525

2626

27-
@pytest.mark.parametrize("log", [False, True])
28-
def test_nystroem_sinkhorn(log):
27+
@pytest.mark.parametrize("log, warn", [[False, False], [True, True]])
28+
def test_nystroem_sinkhorn(log, warn):
2929
# test Nystrom approximation for Sinkhorn (ot plan)
3030
offset = 2
3131
n_samples_per_blob = 50
@@ -68,6 +68,7 @@ def test_nystroem_sinkhorn(log):
6868
verbose=True,
6969
random_state=random_state,
7070
log=log,
71+
warn=warn,
7172
)
7273
if log:
7374
G_nys, log_ = res

0 commit comments

Comments
 (0)