Skip to content

Commit f089545

Browse files
committed
clean doc
1 parent 1716ab0 commit f089545

File tree

1 file changed

+36
-35
lines changed

1 file changed

+36
-35
lines changed

ot/sliced/_sliced_plans.py

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ def sliced_plans(
3030
log=False,
3131
):
3232
r"""
33-
Computes all the permutations that sort the projections of two `(n, d)`
34-
datasets `X` and `Y` on the directions `thetas`.
35-
Each permutation `perm[:, k]` is such that each `X[i, :]` is matched
36-
to `Y[perm[i, k], :]` when projected on `thetas[k, :]`.
33+
Computes all the permutations that sort the projections of two `(ns, nt)`
34+
datasets `X_s` and `X_t` on the directions `thetas`.
35+
Each permutation `perm[:, k]` is such that each :math:`X_s[i, :]` is matched
36+
to `X_t[perm[i, k], :]` when projected on `thetas[k, :]`.
3737
3838
Parameters
3939
----------
@@ -162,8 +162,8 @@ def sliced_plans(
162162

163163

164164
def min_pivot_sliced(
165-
X,
166-
Y,
165+
X_s,
166+
X_t,
167167
a=None,
168168
b=None,
169169
thetas=None,
@@ -177,19 +177,19 @@ def min_pivot_sliced(
177177
r"""
178178
Computes the cost and permutation associated to the min-Pivot Sliced
179179
Discrepancy (introduced as SWGG in [83] and studied further in [84]). Given
180-
the supports `X` and `Y` of two discrete uniform measures with `n` and `m`
180+
the supports `X_s` and `X_t` of two discrete uniform measures with `ns` and `nt`
181181
atoms in dimension `d`, the min-Pivot Sliced Discrepancy goes through
182182
`n_proj` different projections of the measures on random directions, and
183-
retains the couplings that yields the lowest cost between `X` and `Y`
184-
(compared in :math:`\mathbb{R}^d`). When $n=m$, it gives
183+
retains the couplings that yields the lowest cost between `X_s` and `X_t`
184+
(compared in :math:`\mathbb{R}^d`). When `ns=nt`, it gives
185185
186186
.. math::
187-
\mathrm{min\text{-}PS}_p^p(X, Y) \approx
187+
\mathrm{min\text{-}PS}_p^p(X_s, X_t) \approx
188188
\min_{k \in [1, n_{\mathrm{proj}}]} \left(
189-
\frac{1}{n} \sum_{i=1}^n \|X_i - Y_{\sigma_k(i)}\|_2^p \right),
189+
\frac{1}{n_s} \sum_{i=1}^{n_s} \|X_{s,i} - X_{t,\sigma_k(i)}\|_2^p \right),
190190
191191
where :math:`\sigma_k` is a permutation such that ordering the projections
192-
on the axis `thetas[k, :]` matches `X[i, :]` to `Y[\sigma_k(i), :]`.
192+
on the axis `thetas[k, :]` matches :math:`X_s[i, :]` to :math:`X_t[\sigma_k(i), :]`.
193193
194194
.. note::
195195
The computation ignores potential ambiguities in the projections: if
@@ -198,15 +198,18 @@ def min_pivot_sliced(
198198
explosion, only one permutation is retained: this strays from theory in
199199
pathological cases.
200200
201+
.. warning::
202+
Tensorflow and jax only returns dense plans, as they do not support well sparse matrices.
203+
201204
Parameters
202205
----------
203-
X : array-like, shape (n, d)
206+
X_s : array-like, shape (ns, d)
204207
The first set of vectors.
205-
Y : array-like, shape (m, d)
208+
X_t : array-like, shape (nt, d)
206209
The second set of vectors.
207-
a : ndarray of float64, shape (n,), optional
210+
a : ndarray of float64, shape (ns,), optional
208211
Source histogram (default is uniform weight)
209-
b : ndarray of float64, shape (m,), optional
212+
b : ndarray of float64, shape (nt,), optional
210213
Target histogram (default is uniform weight)
211214
thetas : array-like, shape (n_proj, d), optional
212215
The projection directions. If None, random directions will be generated
@@ -262,32 +265,31 @@ def min_pivot_sliced(
262265
2.125
263266
"""
264267

265-
X, Y = list_to_array(X, Y)
268+
X_s, X_t = list_to_array(X_s, X_t)
266269

267270
if a is not None and b is not None and thetas is None:
268-
nx = get_backend(X, Y, a, b)
271+
nx = get_backend(X_s, X_t, a, b)
269272
elif a is not None and b is not None and thetas is not None:
270-
nx = get_backend(X, Y, a, b, thetas)
273+
nx = get_backend(X_s, X_t, a, b, thetas)
271274
elif a is None and b is None and thetas is not None:
272-
nx = get_backend(X, Y, thetas)
275+
nx = get_backend(X_s, X_t, thetas)
273276
else:
274-
nx = get_backend(X, Y)
275-
276-
assert X.ndim == 2, f"X must be a 2d array, got {X.ndim}d array instead"
277-
assert Y.ndim == 2, f"Y must be a 2d array, got {Y.ndim}d array instead"
277+
nx = get_backend(X_s, X_t)
278+
assert X_s.ndim == 2, f"X_s must be a 2d array, got {X_s.ndim}d array instead"
279+
assert X_t.ndim == 2, f"X_t must be a 2d array, got {X_t.ndim}d array instead"
278280

279281
assert (
280-
X.shape[1] == Y.shape[1]
281-
), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns"
282+
X_s.shape[1] == X_t.shape[1]
283+
), f"X_s ({X_s.shape}) and X_t ({X_t.shape}) must have the same number of columns"
282284

283285
if str(nx) in ["tf", "jax"] and not dense:
284286
dense = True
285287
warnings.warn("JAX and TF do not support sparse matrices, converting to dense")
286288

287289
log_dict = {}
288290
G, costs, log_dict_plans = sliced_plans(
289-
X,
290-
Y,
291+
X_s,
292+
X_t,
291293
a,
292294
b,
293295
metric,
@@ -316,8 +318,8 @@ def min_pivot_sliced(
316318
plan["data"],
317319
plan["rows"],
318320
plan["cols"],
319-
shape=(X.shape[0], Y.shape[0]),
320-
type_as=X,
321+
shape=(X_s.shape[0], X_t.shape[0]),
322+
type_as=X_s,
321323
)
322324

323325
if dense:
@@ -343,10 +345,10 @@ def expected_sliced(
343345
beta=0.0,
344346
):
345347
r"""
346-
Computes the Expected Sliced cost and plan between two datasets `X` and
347-
`Y` of shapes `(n, d)` and `(m, d)`. Given a set of `n_proj` projection
348+
Computes the Expected Sliced cost and plan between two datasets `X_s` and
349+
`X_t` of shapes `(ns, d)` and `(nt, d)`. Given a set of `n_proj` projection
348350
directions, the expected sliced plan is obtained by averaging the `n_proj`
349-
1d optimal transport plans between the projections of `X` and `Y` on each
351+
1d optimal transport plans between the projections of `X_s` and `X_t` on each
350352
direction. Expected Sliced was introduced in [85] and further studied in
351353
[84].
352354
@@ -358,8 +360,7 @@ def expected_sliced(
358360
pathological cases.
359361
360362
.. warning::
361-
The function runs on backend but tensorflow and jax are not supported
362-
due to array assignment.
363+
Tensorflow and jax only returns dense plans, as they do not support well sparse matrices.
363364
364365
Parameters
365366
----------

0 commit comments

Comments
 (0)