@@ -30,10 +30,10 @@ def sliced_plans(
3030 log = False ,
3131):
3232 r"""
33- Computes all the permutations that sort the projections of two `(n, d )`
34- datasets `X ` and `Y ` on the directions `thetas`.
35- Each permutation `perm[:, k]` is such that each `X [i, :]` is matched
36- to `Y [perm[i, k], :]` when projected on `thetas[k, :]`.
33+ Computes all the permutations that sort the projections of two `(ns, nt )`
34+ datasets `X_s ` and `X_t ` on the directions `thetas`.
35+ Each permutation `perm[:, k]` is such that each :math:`X_s [i, :]` is matched
36+ to `X_t [perm[i, k], :]` when projected on `thetas[k, :]`.
3737
3838 Parameters
3939 ----------
@@ -162,8 +162,8 @@ def sliced_plans(
162162
163163
164164def min_pivot_sliced (
165- X ,
166- Y ,
165+ X_s ,
166+ X_t ,
167167 a = None ,
168168 b = None ,
169169 thetas = None ,
@@ -177,19 +177,19 @@ def min_pivot_sliced(
177177 r"""
178178 Computes the cost and permutation associated to the min-Pivot Sliced
179179 Discrepancy (introduced as SWGG in [83] and studied further in [84]). Given
180- the supports `X ` and `Y ` of two discrete uniform measures with `n ` and `m `
180+ the supports `X_s ` and `X_t ` of two discrete uniform measures with `ns ` and `nt `
181181 atoms in dimension `d`, the min-Pivot Sliced Discrepancy goes through
182182 `n_proj` different projections of the measures on random directions, and
183- retains the couplings that yields the lowest cost between `X ` and `Y `
184- (compared in :math:`\mathbb{R}^d`). When $n=m$ , it gives
 183+ retains the coupling that yields the lowest cost between `X_s` and `X_t`
184+ (compared in :math:`\mathbb{R}^d`). When `ns=nt` , it gives
185185
186186 .. math::
187- \mathrm{min\text{-}PS}_p^p(X, Y ) \approx
187+ \mathrm{min\text{-}PS}_p^p(X_s, X_t ) \approx
188188 \min_{k \in [1, n_{\mathrm{proj}}]} \left(
189- \frac{1}{n } \sum_{i=1}^n \|X_i - Y_{ \sigma_k(i)}\|_2^p \right),
 189+ \frac{1}{n_s} \sum_{i=1}^{n_s} \|X_{s,i} - X_{t,\sigma_k(i)}\|_2^p \right),
190190
191191 where :math:`\sigma_k` is a permutation such that ordering the projections
192- on the axis `thetas[k, :]` matches `X [i, :]` to `Y [\sigma_k(i), :]`.
192+ on the axis `thetas[k, :]` matches :math:`X_s [i, :]` to :math:`X_t [\sigma_k(i), :]`.
193193
194194 .. note::
195195 The computation ignores potential ambiguities in the projections: if
@@ -198,15 +198,18 @@ def min_pivot_sliced(
198198 explosion, only one permutation is retained: this strays from theory in
199199 pathological cases.
200200
201+ .. warning::
 202+ TensorFlow and JAX only return dense plans, as they do not support sparse matrices well.
203+
201204 Parameters
202205 ----------
203- X : array-like, shape (n , d)
206+ X_s : array-like, shape (ns , d)
204207 The first set of vectors.
205- Y : array-like, shape (m , d)
208+ X_t : array-like, shape (nt , d)
206209 The second set of vectors.
207- a : ndarray of float64, shape (n ,), optional
210+ a : ndarray of float64, shape (ns ,), optional
208211 Source histogram (default is uniform weight)
209- b : ndarray of float64, shape (m ,), optional
212+ b : ndarray of float64, shape (nt ,), optional
210213 Target histogram (default is uniform weight)
211214 thetas : array-like, shape (n_proj, d), optional
212215 The projection directions. If None, random directions will be generated
@@ -262,32 +265,31 @@ def min_pivot_sliced(
262265 2.125
263266 """
264267
265- X , Y = list_to_array (X , Y )
268+ X_s , X_t = list_to_array (X_s , X_t )
266269
267270 if a is not None and b is not None and thetas is None :
268- nx = get_backend (X , Y , a , b )
271+ nx = get_backend (X_s , X_t , a , b )
269272 elif a is not None and b is not None and thetas is not None :
270- nx = get_backend (X , Y , a , b , thetas )
273+ nx = get_backend (X_s , X_t , a , b , thetas )
271274 elif a is None and b is None and thetas is not None :
272- nx = get_backend (X , Y , thetas )
275+ nx = get_backend (X_s , X_t , thetas )
273276 else :
274- nx = get_backend (X , Y )
275-
276- assert X .ndim == 2 , f"X must be a 2d array, got { X .ndim } d array instead"
277- assert Y .ndim == 2 , f"Y must be a 2d array, got { Y .ndim } d array instead"
277+ nx = get_backend (X_s , X_t )
278+ assert X_s .ndim == 2 , f"X_s must be a 2d array, got { X_s .ndim } d array instead"
279+ assert X_t .ndim == 2 , f"X_t must be a 2d array, got { X_t .ndim } d array instead"
278280
279281 assert (
280- X .shape [1 ] == Y .shape [1 ]
281- ), f"X ({ X .shape } ) and Y ({ Y .shape } ) must have the same number of columns"
282+ X_s .shape [1 ] == X_t .shape [1 ]
283+ ), f"X_s ({ X_s .shape } ) and X_t ({ X_t .shape } ) must have the same number of columns"
282284
283285 if str (nx ) in ["tf" , "jax" ] and not dense :
284286 dense = True
285287 warnings .warn ("JAX and TF do not support sparse matrices, converting to dense" )
286288
287289 log_dict = {}
288290 G , costs , log_dict_plans = sliced_plans (
289- X ,
290- Y ,
291+ X_s ,
292+ X_t ,
291293 a ,
292294 b ,
293295 metric ,
@@ -316,8 +318,8 @@ def min_pivot_sliced(
316318 plan ["data" ],
317319 plan ["rows" ],
318320 plan ["cols" ],
319- shape = (X .shape [0 ], Y .shape [0 ]),
320- type_as = X ,
321+ shape = (X_s .shape [0 ], X_t .shape [0 ]),
322+ type_as = X_s ,
321323 )
322324
323325 if dense :
@@ -343,10 +345,10 @@ def expected_sliced(
343345 beta = 0.0 ,
344346):
345347 r"""
346- Computes the Expected Sliced cost and plan between two datasets `X ` and
347- `Y ` of shapes `(n , d)` and `(m , d)`. Given a set of `n_proj` projection
348+ Computes the Expected Sliced cost and plan between two datasets `X_s ` and
349+ `X_t ` of shapes `(ns , d)` and `(nt , d)`. Given a set of `n_proj` projection
348350 directions, the expected sliced plan is obtained by averaging the `n_proj`
349- 1d optimal transport plans between the projections of `X ` and `Y ` on each
351+ 1d optimal transport plans between the projections of `X_s ` and `X_t ` on each
350352 direction. Expected Sliced was introduced in [85] and further studied in
351353 [84].
352354
@@ -358,8 +360,7 @@ def expected_sliced(
358360 pathological cases.
359361
360362 .. warning::
361- The function runs on backend but tensorflow and jax are not supported
362- due to array assignment.
 363+ TensorFlow and JAX only return dense plans, as they do not support sparse matrices well.
363364
364365 Parameters
365366 ----------
0 commit comments