Skip to content

Commit f3030f2

Browse files
committed
remove dependence on coo_matrix and add sparse_ot_dist function into utils
1 parent 659e5fc commit f3030f2

File tree

7 files changed

+203
-136
lines changed

7 files changed

+203
-136
lines changed

.github/CONTRIBUTING.md

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,15 @@ GitHub, clone, and develop on a branch. Steps:
2020
$ cd POT
2121
```
2222

23-
3. Install pre-commit hooks to ensure that your code is properly formatted:
23+
3. Install a recent version of Python. Using an isolated environment such as venv or conda allows you to install a specific version of POT.
24+
For instance, for creating a conda environment with python 3.12 and for activating it:
25+
26+
```bash
27+
$ conda create -n dev-pot-env python=3.12
28+
$ conda activate dev-pot-env
29+
```
30+
31+
4. Install pre-commit hooks to ensure that your code is properly formatted:
2432

2533
```bash
2634
$ pip install pre-commit
@@ -29,27 +37,14 @@ GitHub, clone, and develop on a branch. Steps:
2937

3038
This will install the pre-commit hooks that will run on every commit. If the hooks fail, the commit will be aborted.
3139

32-
4. Create a `feature` branch to hold your development changes:
40+
5. Create a `feature` branch to hold your development changes:
3341

3442
```bash
3543
$ git checkout -b my-feature
3644
```
3745

3846
Always use a `feature` branch. It's good practice to never work on the `master` branch!
3947

40-
5. Install a recent version of Python (e.g. 3.10), using conda for instance. You can create a conda environment and activate it:
41-
42-
```bash
43-
$ conda create -n dev-pot-env python=3.10
44-
$ conda activate dev-pot-env
45-
```
46-
47-
6. Install all the necessary packages in your environment:
48-
49-
```bash
50-
$ pip install -r requirements_all.txt
51-
```
52-
5348
6. Install a compiler with OpenMP support for your platform (see details on the [scikit-learn contributing guide](https://scikit-learn.org/stable/developers/advanced_installation.html#platform-specific-instructions)).
5449
For instance, with macOS, Apple clang does not support OpenMP. One can install the LLVM OpenMP library from homebrew:
5550

@@ -70,6 +65,12 @@ $ pip install -r requirements_all.txt
7065
pip install -e .
7166
```
7267

68+
If you want to install all dependencies, you can use
69+
70+
```bash
71+
pip install -e .[all]
72+
```
73+
7374
8. Develop the feature on your feature branch. Add changed files using `git add` and then `git commit` files:
7475

7576
```bash

examples/sliced-wasserstein/plot_sliced_plans.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
Sliced plan was introduced in [82], and the Expected Sliced plan in [84], both
99
were further studied theoretically in [83].
1010
11-
.. [82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385.
11+
.. [83] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385.
1212
13-
.. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661.
13+
.. [84] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661.
1414
15-
.. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations.
15+
.. [85] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations.
1616
"""
1717

1818
# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>

ot/lp/solver_1d.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,10 @@ def wasserstein_1d(
102102
-------
103103
cost: float/array-like, shape (...)
104104
the batched EMD
105-
plan: list of coo_matrix, optional
105+
plan: list of dictionaries, optional
106106
if return_plan is True, returns the list of the optimal transport plans
107-
between the two (batched) measures as a coo_matrix, default is False
107+
as a list of dictionaries containing the rows, cols and data of the non-zero elements of the transportation matrix.
108+
Default is False
108109
109110
References
110111
----------
@@ -167,15 +168,24 @@ def wasserstein_1d(
167168
u_quantiles_idx = nx.take_along_axis(u_sorter, idx_u, axis=0)
168169
v_quantiles_idx = nx.take_along_axis(v_sorter, idx_v, axis=0)
169170
plan = [
170-
nx.coo_matrix(
171-
delta[:, k],
172-
u_quantiles_idx[:, k],
173-
v_quantiles_idx[:, k],
174-
shape=(n, m),
175-
type_as=u_values,
176-
)
171+
{
172+
"rows": u_quantiles_idx[:, k],
173+
"cols": v_quantiles_idx[:, k],
174+
"data": delta[:, k],
175+
}
177176
for k in range(delta.shape[1])
178177
]
178+
# plan = [
179+
# nx.coo_matrix(
180+
# delta[:, k],
181+
# u_quantiles_idx[:, k],
182+
# v_quantiles_idx[:, k],
183+
# shape=(n, m),
184+
# type_as=u_values,
185+
# )
186+
# for k in range(delta.shape[1])
187+
# ]
188+
179189
if p == 1:
180190
w_1d = nx.sum(delta * diff_quantiles, axis=0)
181191
else:

ot/sliced.py

Lines changed: 56 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import numpy as np
1717
from .backend import get_backend, NumpyBackend
18-
from .utils import list_to_array, get_coordinate_circle, dist
18+
from .utils import list_to_array, get_coordinate_circle, dist, sparse_ot_dist
1919
from .lp import (
2020
wasserstein_circle,
2121
semidiscrete_wasserstein2_unif_circle,
@@ -746,6 +746,10 @@ def sliced_plans(
746746
X.shape[1] == Y.shape[1]
747747
), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns"
748748

749+
if str(nx) in ["tf", "jax"] and not dense:
750+
dense = True
751+
warnings.warn("JAX and TF do not support sparse matrices, converting to dense")
752+
749753
if metric == "euclidean":
750754
p = 2
751755
elif metric == "cityblock":
@@ -770,12 +774,6 @@ def sliced_plans(
770774
else:
771775
n_proj = thetas.shape[0]
772776

773-
def dist(i, j):
774-
if metric == "sqeuclidean":
775-
return nx.sum((X[i] - Y[j]) ** 2, axis=1)
776-
else:
777-
return nx.sum(nx.abs(X[i] - Y[j]) ** p, axis=1) ** (1 / p)
778-
779777
# project on each theta: (n or m, d) -> (n or m, n_proj)
780778
X_theta = X @ thetas.T # shape (n, n_proj)
781779
Y_theta = Y @ thetas.T # shape (m, n_proj)
@@ -784,66 +782,30 @@ def dist(i, j):
784782
# sigma[:, i_proj] is a permutation sorting X_theta[:, i_proj]
785783
sigma = nx.argsort(X_theta, axis=0) # (n, n_proj)
786784
tau = nx.argsort(Y_theta, axis=0) # (m, n_proj)
787-
788-
costs = [nx.sum(dist(sigma[:, k], tau[:, k]) / n) for k in range(n_proj)]
789-
785+
costs = [
786+
sparse_ot_dist(X, Y, sigma[:, k], tau[:, k], metric=metric, p=p)
787+
for k in range(n_proj)
788+
]
790789
a = nx.ones(n) / n
791790
plan = [
792-
nx.coo_matrix(a, sigma[:, k], tau[:, k], shape=(n, m), type_as=a)
793-
for k in range(n_proj)
791+
{"rows": sigma[:, k], "cols": tau[:, k], "data": a} for k in range(n_proj)
794792
]
795-
796793
else: # we compute plans
797794
_, plan = wasserstein_1d(
798795
X_theta, Y_theta, a, b, p, require_sort=True, return_plan=True
799796
)
800-
801-
if str(nx) in ["tf", "jax"]:
802-
if not dense:
803-
if str(nx) == "jax":
804-
warnings.warn(
805-
"JAX does not support sparse matrices, converting to dense"
806-
)
807-
else:
808-
warnings.warn(
809-
"TensorFlow multiple indexing is forbidden, converting to dense"
810-
)
811-
plan_dense = [nx.todense(plan[k]) for k in range(n_proj)]
812-
idx_non_zeros = [nx.where(plan_dense[k] != 0) for k in range(n_proj)]
813-
costs = [
814-
nx.sum(
815-
dist(idx_non_zeros[k][0], idx_non_zeros[k][1])
816-
* plan_dense[k][idx_non_zeros[k][0], idx_non_zeros[k][1]]
817-
)
818-
for k in range(n_proj)
819-
]
820-
else:
821-
if str(nx) == "torch":
822-
plan = [plan[k].coalesce() for k in range(n_proj)]
823-
costs = [
824-
nx.sum(
825-
dist(plan[k].indices()[0], plan[k].indices()[1])
826-
* plan[k].values()
827-
)
828-
for k in range(n_proj)
829-
]
830-
else:
831-
costs = [
832-
nx.sum(dist(plan[k].row, plan[k].col) * plan[k].data)
833-
for k in range(n_proj)
834-
]
835-
836-
if dense and str(nx) not in ["tf", "jax"]:
837-
plan = [nx.todense(plan[k]) for k in range(n_proj)]
838-
elif str(nx) in ["tf", "jax"]:
839-
if not is_perm:
840-
warnings.warn(
841-
"JAX and tensorflow do not support well sparse "
842-
"matrices, converting to dense"
797+
costs = [
798+
sparse_ot_dist(
799+
X,
800+
Y,
801+
plan[k]["rows"],
802+
plan[k]["cols"],
803+
plan[k]["data"],
804+
metric=metric,
805+
p=p,
843806
)
844-
plan = [nx.todense(plan[k]) for k in range(n_proj)]
845-
else:
846-
plan = plan_dense.copy()
807+
for k in range(n_proj)
808+
]
847809

848810
if log:
849811
log_dict = {"X_theta": X_theta, "Y_theta": Y_theta, "thetas": thetas}
@@ -867,7 +829,7 @@ def min_pivot_sliced(
867829
):
868830
r"""
869831
Computes the cost and permutation associated to the min-Pivot Sliced
870-
Discrepancy (introduced as SWGG in [82] and studied further in [83]). Given
832+
Discrepancy (introduced as SWGG in [83] and studied further in [84]). Given
871833
the supports `X` and `Y` of two discrete uniform measures with `n` and `m`
872834
atoms in dimension `d`, the min-Pivot Sliced Discrepancy goes through
873835
`n_proj` different projections of the measures on random directions, and
@@ -930,12 +892,12 @@ def min_pivot_sliced(
930892
931893
References
932894
----------
933-
.. [82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023).
895+
.. [83] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023).
934896
Fast Optimal Transport through Sliced Generalized Wasserstein
935897
Geodesics. Advances in Neural Information Processing Systems, 36,
936898
35350–35385.
937899
938-
.. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport
900+
.. [84] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport
939901
Plans. arXiv preprint 2506.03661.
940902
941903
Examples
@@ -969,6 +931,10 @@ def min_pivot_sliced(
969931
X.shape[1] == Y.shape[1]
970932
), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns"
971933

934+
if str(nx) in ["tf", "jax"] and not dense:
935+
dense = True
936+
warnings.warn("JAX and TF do not support sparse matrices, converting to dense")
937+
972938
log_dict = {}
973939
G, costs, log_dict_plans = sliced_plans(
974940
X,
@@ -996,11 +962,17 @@ def min_pivot_sliced(
996962
"Y_min_theta": log_dict_plans["Y_theta"][:, pos_min],
997963
}
998964

965+
# get the plan from the indices of the non-zero entries of the sparse plan
966+
plan = nx.coo_matrix(
967+
plan["data"],
968+
plan["rows"],
969+
plan["cols"],
970+
shape=(X.shape[0], Y.shape[0]),
971+
type_as=X,
972+
)
973+
999974
if dense:
1000975
plan = nx.todense(plan)
1001-
elif str(nx) in ["tf", "jax"]:
1002-
warnings.warn("JAX and TF do not support sparse matrices, converting to dense")
1003-
plan = nx.todense(plan)
1004976

1005977
if log:
1006978
return plan, cost, log_dict
@@ -1026,8 +998,8 @@ def expected_sliced(
1026998
`Y` of shapes `(n, d)` and `(m, d)`. Given a set of `n_proj` projection
1027999
directions, the expected sliced plan is obtained by averaging the `n_proj`
10281000
1d optimal transport plans between the projections of `X` and `Y` on each
1029-
direction. Expected Sliced was introduced in [84] and further studied in
1030-
[83].
1001+
direction. Expected Sliced was introduced in [85] and further studied in
1002+
[84].
10311003
10321004
.. note::
10331005
The computation ignores potential ambiguities in the projections: if
@@ -1082,9 +1054,9 @@ def expected_sliced(
10821054
10831055
References
10841056
----------
1085-
.. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport
1057+
.. [84] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport
10861058
Plans. arXiv preprint 2506.03661.
1087-
.. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi
1059+
.. [85] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi
10881060
A., Kolouri, S. (2024). Expected Sliced Transport Plans.
10891061
International Conference on Learning Representations.
10901062
@@ -1118,11 +1090,9 @@ def expected_sliced(
11181090
X.shape[1] == Y.shape[1]
11191091
), f"X ({X.shape}) and Y ({Y.shape}) must have the same number of columns"
11201092

1121-
if str(nx) in ["tf", "jax"]:
1122-
raise NotImplementedError(
1123-
f"expected_sliced is not implemented for the {str(nx)} backend due"
1124-
"to array assignment."
1125-
)
1093+
if str(nx) in ["tf", "jax"] and not dense:
1094+
dense = True
1095+
warnings.warn("JAX and TF do not support sparse matrices, converting to dense")
11261096

11271097
n = X.shape[0]
11281098
m = Y.shape[0]
@@ -1131,40 +1101,34 @@ def expected_sliced(
11311101
G, costs, log_dict_plans = sliced_plans(
11321102
X, Y, a, b, metric, p, thetas, n_proj=n_proj, log=True, dense=False
11331103
)
1134-
if log:
1135-
log_dict = {"thetas": log_dict_plans["thetas"], "costs": costs, "G": G}
11361104

11371105
if beta != 0.0: # computing the temperature weighting
11381106
log_factors = -beta * list_to_array(costs)
11391107
weights = nx.exp(log_factors - nx.logsumexp(log_factors))
11401108
cost = nx.sum(list_to_array(costs) * weights)
1141-
11421109
else: # uniform weights
11431110
if n_proj is None:
11441111
n_proj = thetas.shape[0]
11451112
weights = nx.ones(n_proj) / n_proj
11461113

1147-
log_dict["weights"] = weights
1148-
if str(nx) == "torch":
1149-
weights = nx.concatenate([G[i].values() * weights[i] for i in range(len(G))])
1150-
X_idx = nx.concatenate([G[i].indices()[0] for i in range(len(G))])
1151-
Y_idx = nx.concatenate([G[i].indices()[1] for i in range(len(G))])
1152-
else:
1153-
weights = nx.concatenate([G[i].data * weights[i] for i in range(len(G))])
1154-
X_idx = nx.concatenate([G[i].row for i in range(len(G))])
1155-
Y_idx = nx.concatenate([G[i].col for i in range(len(G))])
1156-
plan = nx.coo_matrix(weights, X_idx, Y_idx, shape=(n, m), type_as=weights)
1114+
weights_e = nx.concatenate([G[i]["data"] * weights[i] for i in range(len(G))])
1115+
X_idx = nx.concatenate([G[i]["rows"] for i in range(len(G))])
1116+
Y_idx = nx.concatenate([G[i]["cols"] for i in range(len(G))])
11571117

1158-
if beta == 0.0: # otherwise already computed above
1159-
cost = plan.multiply(dist(X, Y, metric=metric, p=p)).sum()
1118+
plan = nx.coo_matrix(weights_e, X_idx, Y_idx, shape=(n, m), type_as=weights)
11601119

11611120
if dense:
11621121
plan = nx.todense(plan)
1163-
elif str(nx) == "jax":
1164-
warnings.warn("JAX does not support sparse matrices, converting to dense")
1165-
plan = nx.todense(plan)
1122+
1123+
if beta == 0.0:
1124+
if dense:
1125+
cost = nx.sum(plan * dist(X, Y, metric=metric, p=p))
1126+
else:
1127+
cost = plan.multiply(dist(X, Y, metric=metric, p=p)).sum()
11661128

11671129
if log:
1130+
log_dict = {"thetas": log_dict_plans["thetas"], "costs": costs, "G": G}
1131+
log_dict["weights"] = weights
11681132
return plan, cost, log_dict
11691133
else:
11701134
return plan, cost

0 commit comments

Comments
 (0)