Skip to content

Commit 26f1380

Browse files
authored
Add lazy EMD solver with O(n) memory requirement (#788)
* Add lazy EMD solver with on-the-fly distance computation
  - Implement emd_c_lazy in C++ network simplex for memory-efficient OT
  - Add lazy mode to emd2() accepting coordinates (X_a, X_b) instead of a cost matrix
  - Support sqeuclidean, euclidean, and cityblock metrics
  - Add __restrict__ for SIMD optimization
  - Remove debug output from network_simplex_simple.h
  - Add tests for lazy solver and metric variants
* Add emd2_lazy function and fix SciPy sparse matrix compatibility
* Small fix for errors not appearing locally
* Fix SciPy version compatibility for distance metrics
* Fix issues; add release info
* Modify OpenMP implementation to allow build
* Remove bulk set of the cost array
* Update release notes
1 parent 4c49769 commit 26f1380

16 files changed

+671
-121
lines changed

RELEASES.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22

33
## 0.9.7.dev0
44

5-
This new release adds support for sparse cost matrices in the exact EMD solver. Users can now pass sparse cost matrices (e.g., k-NN graphs, sparse graphs) and receive sparse transport plans, significantly reducing memory footprint for large-scale problems. The implementation is backend-agnostic, automatically handling scipy.sparse for NumPy and torch.sparse for PyTorch, and preserves full gradient computation capabilities for automatic differentiation in PyTorch. This enables efficient solving of OT problems on graphs with millions of nodes where only a sparse subset of edges have finite costs.
5+
This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation.
66

77
#### New features
8+
- Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788)
89
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782)
9-
- Geomloss function now handles both scalar and slice indices for i and j, using backend-agnostic reshaping. This allows plan[i,:] and plan[:,j] (PR #785)
10+
- Geomloss function now handles both scalar and slice indices for i and j (PR #785)
1011
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
1112

1213
#### Closed issues
13-
- Fix O(n³) performance bottleneck in sparse bipartite graph arc iteration (PR #785)
14+
- Fix NumPy 2.x compatibility in Brenier potential bounds (PR #788)
15+
- Fix MSVC Windows build by removing __restrict__ keyword (PR #788)
16+
- Fix O(n³) performance bottleneck in sparse bipartite graph arc iteration (PR #785)
1417
- Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770)
1518
- Add test for build from source (PR #772, Issue #764)
1619
- Fix device for batch Ot solver in `ot.batch` (PR #784, Issue #783)

ot/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from .lp import (
4242
emd,
4343
emd2,
44+
emd2_lazy,
4445
emd_1d,
4546
emd2_1d,
4647
wasserstein_1d,
@@ -82,6 +83,7 @@
8283
__all__ = [
8384
"emd",
8485
"emd2",
86+
"emd2_lazy",
8587
"emd_1d",
8688
"sinkhorn",
8789
"sinkhorn2",

ot/gromov/_estimators.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ def GW_distance_estimation(
122122

123123
for i in range(nb_samples_p):
124124
if nx.issparse(T):
125-
T_indexi = nx.reshape(nx.todense(T[index_i[i], :]), (-1,))
126-
T_indexj = nx.reshape(nx.todense(T[index_j[i], :]), (-1,))
125+
T_indexi = nx.reshape(nx.todense(T[[index_i[i]], :]), (-1,))
126+
T_indexj = nx.reshape(nx.todense(T[[index_j[i]], :]), (-1,))
127127
else:
128128
T_indexi = T[index_i[i], :]
129129
T_indexj = T[index_j[i], :]
@@ -243,16 +243,18 @@ def pointwise_gromov_wasserstein(
243243
index = np.zeros(2, dtype=int)
244244

245245
# Initialize with default marginal
246-
index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
247-
index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q))
246+
index[0] = int(generator.choice(len_p, size=1, p=nx.to_numpy(p)).item())
247+
index[1] = int(generator.choice(len_q, size=1, p=nx.to_numpy(q)).item())
248248
T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))
249249

250250
best_gw_dist_estimated = np.inf
251251
for cpt in range(max_iter):
252-
index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
253-
T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
254-
index[1] = generator.choice(
255-
len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0))
252+
index[0] = int(generator.choice(len_p, size=1, p=nx.to_numpy(p)).item())
253+
T_index0 = nx.reshape(nx.todense(T[[index[0]], :]), (-1,))
254+
index[1] = int(
255+
generator.choice(
256+
len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0))
257+
).item()
256258
)
257259

258260
if alpha == 1:
@@ -404,10 +406,15 @@ def sampled_gromov_wasserstein(
404406
)
405407
Lik = 0
406408
for i, index0_i in enumerate(index0):
409+
T_row = (
410+
nx.reshape(nx.todense(T[[index0_i], :]), (-1,))
411+
if nx.issparse(T)
412+
else T[index0_i, :]
413+
)
407414
index1 = generator.choice(
408415
len_q,
409416
size=nb_samples_grad_q,
410-
p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])),
417+
p=nx.to_numpy(T_row / nx.sum(T_row)),
411418
replace=False,
412419
)
413420
# If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly.

ot/lp/EMD.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,21 @@ int EMD_wrap_sparse(
5151
uint64_t maxIter // Maximum iterations for solver
5252
);
5353

54+
int EMD_wrap_lazy(
55+
int n1, // Number of source points
56+
int n2, // Number of target points
57+
double *X, // Source weights (n1)
58+
double *Y, // Target weights (n2)
59+
double *coords_a, // Source coordinates (n1 x dim)
60+
double *coords_b, // Target coordinates (n2 x dim)
61+
int dim, // Dimension of coordinates
62+
int metric, // Distance metric: 0=sqeuclidean, 1=euclidean, 2=cityblock
63+
double *G, // Output: transport plan (n1 x n2)
64+
double *alpha, // Output: dual variables for sources (n1)
65+
double *beta, // Output: dual variables for targets (n2)
66+
double *cost, // Output: total transportation cost
67+
uint64_t maxIter // Maximum iterations for solver
68+
);
69+
5470

5571
#endif

ot/lp/EMD_wrapper.cpp

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,4 +370,108 @@ int EMD_wrap_sparse(
370370
}
371371
}
372372
return ret;
373-
}
373+
}
374+
375+
int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b,
376+
int dim, int metric, double *G, double *alpha, double *beta,
377+
double *cost, uint64_t maxIter) {
378+
using namespace lemon;
379+
typedef FullBipartiteDigraph Digraph;
380+
DIGRAPH_TYPEDEFS(Digraph);
381+
382+
// Filter source nodes with non-zero weights
383+
std::vector<int> idx_a;
384+
std::vector<double> weights_a_filtered;
385+
std::vector<double> coords_a_filtered;
386+
387+
// Reserve space to avoid reallocations
388+
idx_a.reserve(n1);
389+
weights_a_filtered.reserve(n1);
390+
coords_a_filtered.reserve(n1 * dim);
391+
392+
for (int i = 0; i < n1; i++) {
393+
if (X[i] > 0) {
394+
idx_a.push_back(i);
395+
weights_a_filtered.push_back(X[i]);
396+
for (int d = 0; d < dim; d++) {
397+
coords_a_filtered.push_back(coords_a[i * dim + d]);
398+
}
399+
}
400+
}
401+
int n = idx_a.size();
402+
403+
// Filter target nodes with non-zero weights
404+
std::vector<int> idx_b;
405+
std::vector<double> weights_b_filtered;
406+
std::vector<double> coords_b_filtered;
407+
408+
// Reserve space to avoid reallocations
409+
idx_b.reserve(n2);
410+
weights_b_filtered.reserve(n2);
411+
coords_b_filtered.reserve(n2 * dim);
412+
413+
for (int j = 0; j < n2; j++) {
414+
if (Y[j] > 0) {
415+
idx_b.push_back(j);
416+
weights_b_filtered.push_back(-Y[j]); // Demand is negative supply
417+
for (int d = 0; d < dim; d++) {
418+
coords_b_filtered.push_back(coords_b[j * dim + d]);
419+
}
420+
}
421+
}
422+
int m = idx_b.size();
423+
424+
if (n == 0 || m == 0) {
425+
*cost = 0.0;
426+
return 0;
427+
}
428+
429+
// Create full bipartite graph
430+
Digraph di(n, m);
431+
432+
NetworkSimplexSimple<Digraph, double, double, node_id_type> net(
433+
di, true, (int)(n + m), (uint64_t)(n) * (uint64_t)(m), maxIter
434+
);
435+
436+
// Set supplies
437+
net.supplyMap(&weights_a_filtered[0], n, &weights_b_filtered[0], m);
438+
439+
// Enable lazy cost computation - costs will be computed on-the-fly
440+
net.setLazyCost(&coords_a_filtered[0], &coords_b_filtered[0], dim, metric, n, m);
441+
442+
// Run solver
443+
int ret = net.run();
444+
445+
if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) {
446+
*cost = 0;
447+
448+
// Initialize output arrays
449+
for (int i = 0; i < n1 * n2; i++) G[i] = 0.0;
450+
for (int i = 0; i < n1; i++) alpha[i] = 0.0;
451+
for (int i = 0; i < n2; i++) beta[i] = 0.0;
452+
453+
// Extract solution
454+
Arc a;
455+
di.first(a);
456+
for (; a != INVALID; di.next(a)) {
457+
int i = di.source(a);
458+
int j = di.target(a) - n;
459+
460+
int orig_i = idx_a[i];
461+
int orig_j = idx_b[j];
462+
463+
double flow = net.flow(a);
464+
G[orig_i * n2 + orig_j] = flow;
465+
466+
alpha[orig_i] = -net.potential(i);
467+
beta[orig_j] = net.potential(j + n);
468+
469+
if (flow > 0) {
470+
double c = net.computeLazyCost(i, j);
471+
*cost += flow * c;
472+
}
473+
}
474+
}
475+
476+
return ret;
477+
}

ot/lp/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# License: MIT License
1010

1111
from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize
12-
from ._network_simplex import emd, emd2
12+
from ._network_simplex import emd, emd2, emd2_lazy
1313
from ._barycenter_solvers import (
1414
barycenter,
1515
free_support_barycenter,
@@ -35,6 +35,7 @@
3535
__all__ = [
3636
"emd",
3737
"emd2",
38+
"emd2_lazy",
3839
"barycenter",
3940
"free_support_barycenter",
4041
"cvx",

0 commit comments

Comments
 (0)