Skip to content

Commit 83d05f9

Browse files
authored
[MRG] Warmstart with dual potentials for exact OT solver (#793)
* Implementation of warmstart for network simplex can make use off precomputed potentials from sinkhorn or even related simplex * optimise initial setup for watmstart using heap * Update Releases and test file * changed some interfaces and function names * small doc fix
1 parent 5c92598 commit 83d05f9

File tree

9 files changed

+553
-27
lines changed

9 files changed

+553
-27
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
66

77
#### New features
88
- Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788)
9+
- Add Warmstart feature to the EMD solver for existing potentials (PR #793)
910
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782)
1011
- Geomloss function now handles both scalar and slice indices for i and j (PR #785)
1112
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)

ot/lp/EMD.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ enum ProblemType {
2929
MAX_ITER_REACHED
3030
};
3131

32-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
32+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init);
3333
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);
3434

3535
int EMD_wrap_sparse(

ot/lp/EMD_wrapper.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222

2323

2424
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
25-
double* alpha, double* beta, double *cost, uint64_t maxIter) {
25+
double* alpha, double* beta, double *cost, uint64_t maxIter,
26+
double* alpha_init, double* beta_init) {
2627
// beware M and C are stored in row major C style!!!
2728

2829
using namespace lemon;
@@ -93,10 +94,24 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
9394
}
9495
}
9596

97+
// Set warmstart potentials if provided
98+
if (alpha_init != nullptr && beta_init != nullptr) {
99+
// Compress warmstart potentials to only non-zero entries
100+
std::vector<double> alpha_compressed(n);
101+
std::vector<double> beta_compressed(m);
102+
for (uint64_t i = 0; i < n; i++) {
103+
alpha_compressed[i] = alpha_init[indI[i]];
104+
}
105+
for (uint64_t j = 0; j < m; j++) {
106+
beta_compressed[j] = beta_init[indJ[j]];
107+
}
108+
net.setWarmstartPotentials(&alpha_compressed[0], &beta_compressed[0], (int)n, (int)m);
109+
}
96110

97111
// Solve the problem with the network simplex algorithm
98112

99113
int ret=net.run();
114+
100115
uint64_t i, j;
101116
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
102117
*cost = 0;

ot/lp/_network_simplex.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def emd(
172172
center_dual=True,
173173
numThreads=1,
174174
check_marginals=True,
175+
potentials_init=None,
175176
):
176177
r"""Solves the Earth Movers distance problem and returns the OT matrix
177178
@@ -237,6 +238,11 @@ def emd(
237238
check_marginals: bool, optional (default=True)
238239
If True, checks that the marginals mass are equal. If False, skips the
239240
check.
241+
potentials_init: tuple of two arrays (alpha, beta), optional (default=None)
242+
Warmstart dual potentials to accelerate convergence. Should be a tuple
243+
(alpha, beta) where alpha is shape (ns,) and beta is shape (nt,).
244+
These potentials are used to guide initial pivots in the network simplex.
245+
Typically obtained from a previous EMD solve or Sinkhorn approximation.
240246
241247
.. note:: The solver automatically detects sparse format using the backend's
242248
:py:meth:`issparse` method. For sparse inputs:
@@ -373,8 +379,18 @@ def emd(
373379
a, b, edge_sources, edge_targets, edge_costs, numItermax
374380
)
375381
else:
382+
# Prepare warmstart if provided
383+
alpha_init = None
384+
beta_init = None
385+
if potentials_init is not None:
386+
alpha_init, beta_init = potentials_init
387+
alpha_init = np.asarray(alpha_init, dtype=np.float64)
388+
beta_init = np.asarray(beta_init, dtype=np.float64)
389+
376390
# Dense solver
377-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
391+
G, cost, u, v, result_code = emd_c(
392+
a, b, M, numItermax, numThreads, alpha_init, beta_init
393+
)
378394

379395
# ============================================================================
380396
# POST-PROCESS DUAL VARIABLES AND CREATE TRANSPORT PLAN
@@ -448,6 +464,7 @@ def emd2(
448464
center_dual=True,
449465
numThreads=1,
450466
check_marginals=True,
467+
potentials_init=None,
451468
):
452469
r"""Solves the Earth Movers distance problem and returns the loss
453470
@@ -513,6 +530,11 @@ def emd2(
513530
check_marginals: bool, optional (default=True)
514531
If True, checks that the marginals mass are equal. If False, skips the
515532
check.
533+
potentials_init: tuple of two arrays (alpha, beta), optional (default=None)
534+
Warmstart dual potentials to accelerate convergence. Should be a tuple
535+
(alpha, beta) where alpha is shape (ns,) and beta is shape (nt,).
536+
These potentials are used to guide initial pivots in the network simplex.
537+
Typically obtained from a previous EMD solve or Sinkhorn approximation.
516538
517539
.. note:: The solver automatically detects sparse format using the backend's
518540
:py:meth:`issparse` method. For sparse inputs:
@@ -656,8 +678,18 @@ def f(b):
656678
emd_c_sparse(a, b, edge_sources, edge_targets, edge_costs, numItermax)
657679
)
658680
else:
681+
# Prepare warmstart if provided
682+
alpha_init = None
683+
beta_init = None
684+
if potentials_init is not None:
685+
alpha_init, beta_init = potentials_init
686+
alpha_init = np.asarray(alpha_init, dtype=np.float64)
687+
beta_init = np.asarray(beta_init, dtype=np.float64)
688+
659689
# Solve dense EMD
660-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
690+
G, cost, u, v, result_code = emd_c(
691+
a, b, M, numItermax, numThreads, alpha_init, beta_init
692+
)
661693

662694
# Center dual potentials
663695
if center_dual:

ot/lp/emd_wrap.pyx

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import warnings
2020

2121

2222
cdef extern from "EMD.h":
23-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
23+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
2424
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
2525
int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil
2626
int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
@@ -42,7 +42,7 @@ def check_result(result_code):
4242

4343
@cython.boundscheck(False)
4444
@cython.wraparound(False)
45-
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads):
45+
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads, alpha_init=None, beta_init=None):
4646
"""
4747
Solves the Earth Movers distance problem and returns the optimal transport matrix
4848
@@ -81,6 +81,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
8181
max_iter : uint64_t
8282
The maximum number of iterations before stopping the optimization
8383
algorithm if it has not converged.
84+
alpha_init : (ns,) numpy.ndarray, float64, optional
85+
Initial dual potentials for sources (warmstart)
86+
beta_init : (nt,) numpy.ndarray, float64, optional
87+
Initial dual potentials for targets (warmstart)
8488
8589
Returns
8690
-------
@@ -101,6 +105,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
101105
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0])
102106

103107
cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0)
108+
109+
# Warmstart potentials
110+
cdef np.ndarray[double, ndim=1, mode="c"] alpha_init_c
111+
cdef np.ndarray[double, ndim=1, mode="c"] beta_init_c
112+
cdef double* alpha_init_ptr = NULL
113+
cdef double* beta_init_ptr = NULL
104114

105115
if not len(a):
106116
a=np.ones((n1,))/n1
@@ -110,11 +120,18 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
110120

111121
# init OT matrix
112122
G=np.zeros([n1, n2])
123+
124+
# Setup warmstart pointers if provided
125+
if alpha_init is not None and beta_init is not None:
126+
alpha_init_c = np.ascontiguousarray(alpha_init, dtype=np.float64)
127+
beta_init_c = np.ascontiguousarray(beta_init, dtype=np.float64)
128+
alpha_init_ptr = <double*> alpha_init_c.data
129+
beta_init_ptr = <double*> beta_init_c.data
113130

114131
# calling the function
115132
with nogil:
116133
if numThreads == 1:
117-
result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
134+
result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, alpha_init_ptr, beta_init_ptr)
118135
else:
119136
result_code = EMD_wrap_omp(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, numThreads)
120137
return G, cost, alpha, beta, result_code

0 commit comments

Comments
 (0)